wip -rebase
parent
ab3ac3f25b
commit
bc3d4eb403
|
@ -95,11 +95,11 @@ class ClassAttn(nn.Module):
|
|||
q = q * self.scale
|
||||
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = torch.matmul(q, k.transpose(-2, -1))
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
|
||||
x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C)
|
||||
x_cls = self.proj(x_cls)
|
||||
x_cls = self.proj_drop(x_cls)
|
||||
|
||||
|
@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module):
|
|||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = torch.matmul(q, k.transpose(-2, -1))
|
||||
|
||||
attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
|
@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module):
|
|||
attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
|
|
@ -105,7 +105,7 @@ class ConvRelPosEnc(nn.Module):
|
|||
def forward(self, q, v, size: Tuple[int, int]):
|
||||
B, h, N, Ch = q.shape
|
||||
H, W = size
|
||||
assert N == 1 + H * W
|
||||
torch._assert(N == 1 + H * W, '')
|
||||
|
||||
# Convolutional relative position encoding.
|
||||
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
|
||||
|
@ -149,8 +149,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module):
|
|||
|
||||
# Factorized attention.
|
||||
k_softmax = k.softmax(dim=2)
|
||||
factor_att = k_softmax.transpose(-1, -2) @ v
|
||||
factor_att = q @ factor_att
|
||||
factor_att = torch.matmul(k_softmax.transpose(-1, -2), v)
|
||||
factor_att = torch.matmul(q, factor_att)
|
||||
|
||||
# Convolutional relative position encoding.
|
||||
crpe = self.crpe(q, v, size=size) # [B, h, N, Ch]
|
||||
|
@ -177,7 +177,7 @@ class ConvPosEnc(nn.Module):
|
|||
def forward(self, x, size: Tuple[int, int]):
|
||||
B, N, C = x.shape
|
||||
H, W = size
|
||||
assert N == 1 + H * W
|
||||
torch._assert(N == 1 + H * W, '')
|
||||
|
||||
# Extract CLS token and image tokens.
|
||||
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
|
||||
|
@ -275,7 +275,7 @@ class ParallelBlock(nn.Module):
|
|||
""" Feature map interpolation. """
|
||||
B, N, C = x.shape
|
||||
H, W = size
|
||||
assert N == 1 + H * W
|
||||
torch._assert(N == 1 + H * W, '')
|
||||
|
||||
cls_token = x[:, :1, :]
|
||||
img_tokens = x[:, 1:, :]
|
||||
|
|
|
@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg
|
|||
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
|
||||
from .registry import register_model
|
||||
from .vision_transformer_hybrid import HybridEmbed
|
||||
from .fx_features import register_leaf_module
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -56,6 +57,7 @@ default_cfgs = {
|
|||
}
|
||||
|
||||
|
||||
@register_leaf_module # FX can't symbolically trace control flow in forward method
|
||||
class GPSA(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
|
||||
locality_strength=1.):
|
||||
|
@ -82,7 +84,7 @@ class GPSA(nn.Module):
|
|||
self.rel_indices = self.get_rel_indices(N)
|
||||
attn = self.get_attention(x)
|
||||
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
@ -93,7 +95,7 @@ class GPSA(nn.Module):
|
|||
q, k = qk[0], qk[1]
|
||||
pos_score = self.rel_indices.expand(B, -1, -1, -1)
|
||||
pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
|
||||
patch_score = (q @ k.transpose(-2, -1)) * self.scale
|
||||
patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
patch_score = patch_score.softmax(dim=-1)
|
||||
pos_score = pos_score.softmax(dim=-1)
|
||||
|
||||
|
@ -178,11 +180,11 @@ class MHSA(nn.Module):
|
|||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
|
|
@ -22,6 +22,7 @@ import torch.nn.functional as F
|
|||
|
||||
from .helpers import to_2tuple, make_divisible
|
||||
from .weight_init import trunc_normal_
|
||||
from timm.models.fx_helpers import fx_and
|
||||
|
||||
|
||||
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
||||
|
@ -36,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
|||
permute_mask: permute output dim according to this
|
||||
"""
|
||||
B, H, W, dim = q.shape
|
||||
x = (q @ rel_k.transpose(-1, -2))
|
||||
x = torch.matmul(q, rel_k.transpose(-1, -2))
|
||||
x = x.reshape(-1, W, 2 * W -1)
|
||||
|
||||
# pad to shift from relative to absolute indexing
|
||||
|
@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.pos_embed.height
|
||||
assert W == self.pos_embed.width
|
||||
torch._assert(H == self.pos_embed.height, '')
|
||||
torch._assert(W == self.pos_embed.width, '')
|
||||
|
||||
x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
|
||||
|
||||
|
@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module):
|
|||
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
|
||||
out = self.pool(out)
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
@ -72,9 +72,9 @@ class EvoNormSample2d(nn.Module):
|
|||
nn.init.ones_(self.v)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dim() == 4, 'expected 4D input'
|
||||
torch._assert(x.dim() == 4, 'expected 4D input')
|
||||
B, C, H, W = x.shape
|
||||
assert C % self.groups == 0
|
||||
torch._assert(C % self.groups == 0, '')
|
||||
if self.apply_act:
|
||||
n = x * (x * self.v).sigmoid()
|
||||
x = x.reshape(B, self.groups, -1)
|
||||
|
|
|
@ -7,6 +7,7 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet
|
|||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@ -52,7 +53,7 @@ class GlobalContext(nn.Module):
|
|||
if self.conv_attn is not None:
|
||||
attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
|
||||
attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
|
||||
context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
|
||||
context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn)
|
||||
context = context.view(B, C, 1, 1)
|
||||
else:
|
||||
context = x.mean(dim=(2, 3), keepdim=True)
|
||||
|
|
|
@ -24,6 +24,7 @@ import torch.nn.functional as F
|
|||
|
||||
from .helpers import make_divisible
|
||||
from .weight_init import trunc_normal_
|
||||
from timm.models.fx_helpers import fx_and
|
||||
|
||||
|
||||
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
||||
|
@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
|
|||
rel_size = rel_k.shape[0]
|
||||
win_size = (rel_size + 1) // 2
|
||||
|
||||
x = (q @ rel_k.transpose(-1, -2))
|
||||
x = torch.matmul(q, rel_k.transpose(-1, -2))
|
||||
x = x.reshape(-1, W, rel_size)
|
||||
|
||||
# pad to shift from relative to absolute indexing
|
||||
|
@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
assert H % self.block_size == 0
|
||||
assert W % self.block_size == 0
|
||||
torch._assert(H % self.block_size == 0, '')
|
||||
torch._assert(W % self.block_size == 0, '')
|
||||
num_h_blocks = H // self.block_size
|
||||
num_w_blocks = W // self.block_size
|
||||
num_blocks = num_h_blocks * num_w_blocks
|
||||
|
|
|
@ -116,8 +116,8 @@ class LambdaLayer(nn.Module):
|
|||
v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
|
||||
k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
|
||||
|
||||
content_lam = k @ v # B, K, V
|
||||
content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
|
||||
content_lam = torch.matmul(k, v) # B, K, V
|
||||
content_out = torch.matmul(q, content_lam.unsqueeze(1)) # B, num_heads, M, V
|
||||
|
||||
if self.pos_emb is None:
|
||||
position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
|
||||
|
|
|
@ -10,6 +10,7 @@ from torch.nn import functional as F
|
|||
|
||||
from .conv_bn_act import ConvBnAct
|
||||
from .helpers import make_divisible
|
||||
from timm.models.fx_helpers import fx_and
|
||||
|
||||
|
||||
class NonLocalAttn(nn.Module):
|
||||
|
@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
|
|||
|
||||
def resize_mat(self, x, t: int):
|
||||
B, C, block_size, block_size1 = x.shape
|
||||
assert block_size == block_size1
|
||||
torch._assert(block_size == block_size1, '')
|
||||
if t <= 1:
|
||||
return x
|
||||
x = x.view(B * C, -1, 1, 1)
|
||||
|
@ -95,7 +96,7 @@ class BilinearAttnTransform(nn.Module):
|
|||
return x
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
|
||||
torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '')
|
||||
B, C, H, W = x.shape
|
||||
out = self.conv1(x)
|
||||
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
|
||||
|
|
|
@ -9,7 +9,11 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
from torch import nn as nn
|
||||
|
||||
from .helpers import to_2tuple
|
||||
<<<<<<< HEAD
|
||||
from .trace_utils import _assert
|
||||
=======
|
||||
from timm.models.fx_helpers import fx_and
|
||||
>>>>>>> Make all models FX traceable
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
|
|
|
@ -34,7 +34,7 @@ class SelectiveKernelAttn(nn.Module):
|
|||
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.num_paths
|
||||
torch._assert(x.shape[1] == self.num_paths, '')
|
||||
x = x.sum(1).mean((2, 3), keepdim=True)
|
||||
x = self.fc_reduce(x)
|
||||
x = self.bn(x)
|
||||
|
|
|
@ -0,0 +1,183 @@
|
|||
""" Shifted Window Attn
|
||||
|
||||
This is a WIP experiment to apply windowed attention from the Swin Transformer
|
||||
to a stand-alone module for use as an attn block in conv nets.
|
||||
|
||||
Based on original swin window code at https://github.com/microsoft/Swin-Transformer
|
||||
Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .drop import DropPath
|
||||
from .helpers import to_2tuple
|
||||
from .weight_init import trunc_normal_
|
||||
from timm.models.fx_helpers import fx_float_to_int
|
||||
|
||||
|
||||
def window_partition(x, win_size: int):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
win_size (int): window size
|
||||
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, win_size: int, H: int, W: int):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
win_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size))
|
||||
x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
win_size (int): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8,
|
||||
qkv_bias=True, attn_drop=0.):
|
||||
|
||||
super().__init__()
|
||||
self.dim_out = dim_out or dim
|
||||
self.feat_size = to_2tuple(feat_size)
|
||||
self.win_size = win_size
|
||||
self.shift_size = shift_size or win_size // 2
|
||||
if min(self.feat_size) <= win_size:
|
||||
# if window size is larger than input resolution, we don't partition windows
|
||||
self.shift_size = 0
|
||||
self.win_size = min(self.feat_size)
|
||||
assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size"
|
||||
self.num_heads = num_heads
|
||||
head_dim = self.dim_out // num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
if self.shift_size > 0:
|
||||
# calculate attention mask for SW-MSA
|
||||
H, W = self.feat_size
|
||||
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
||||
h_slices = (
|
||||
slice(0, -self.win_size),
|
||||
slice(-self.win_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
w_slices = (
|
||||
slice(0, -self.win_size),
|
||||
slice(-self.win_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.win_size * self.win_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
else:
|
||||
attn_mask = None
|
||||
self.register_buffer("attn_mask", attn_mask)
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
# 2 * Wh - 1 * 2 * Ww - 1, nH
|
||||
torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.win_size)
|
||||
coords_w = torch.arange(self.win_size)
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += self.win_size - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.win_size - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
|
||||
|
||||
def reset_parameters(self):
|
||||
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
shifted_x = x
|
||||
|
||||
# partition windows
|
||||
win_size_sq = self.win_size * self.win_size
|
||||
x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, window_size*window_size, C
|
||||
BW, N, _ = x_windows.shape
|
||||
|
||||
qkv = self.qkv(x_windows)
|
||||
qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
q = q * self.scale
|
||||
attn = torch.matmul(q, k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1)
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
if self.attn_mask is not None:
|
||||
num_win = self.attn_mask.shape[0]
|
||||
attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out)
|
||||
|
||||
# merge windows
|
||||
x = x.view(-1, self.win_size, self.win_size, self.dim_out)
|
||||
shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2)
|
||||
x = self.pool(x)
|
||||
return x
|
||||
|
||||
|
|
@ -293,10 +293,10 @@ class Attention(nn.Module):
|
|||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module):
|
|||
v = v.permute(0, 2, 1, 3) # BHNC
|
||||
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
|
||||
|
||||
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
|
|
@ -26,10 +26,12 @@ from torch import nn
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, named_apply
|
||||
from .fx_helpers import fx_float_to_int
|
||||
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
|
||||
from .layers import create_conv2d, create_pool2d, to_ntuple
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -83,12 +85,12 @@ class Attention(nn.Module):
|
|||
qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
# (B, H, T, N, C'), permute -> (B, T, N, C', H)
|
||||
x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
|
||||
x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x # (B, T, N, C)
|
||||
|
@ -128,8 +130,8 @@ class ConvPool(nn.Module):
|
|||
"""
|
||||
x is expected to have shape (B, C, H, W)
|
||||
"""
|
||||
assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims'
|
||||
assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims'
|
||||
torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
|
||||
torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
|
||||
x = self.conv(x)
|
||||
# Layer norm done over channel dim only
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
@ -144,8 +146,8 @@ def blockify(x, block_size: int):
|
|||
block_size (int): edge length of a single square block in units of H, W
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
assert H % block_size == 0, '`block_size` must divide input height evenly'
|
||||
assert W % block_size == 0, '`block_size` must divide input width evenly'
|
||||
torch._assert(H % block_size == 0, '`block_size` must divide input height evenly')
|
||||
torch._assert(W % block_size == 0, '`block_size` must divide input width evenly')
|
||||
grid_height = H // block_size
|
||||
grid_width = W // block_size
|
||||
x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
|
||||
|
@ -160,7 +162,7 @@ def deblockify(x, block_size: int):
|
|||
block_size (int): edge length of a single square block in units of desired H, W
|
||||
"""
|
||||
B, T, _, C = x.shape
|
||||
grid_size = int(math.sqrt(T))
|
||||
grid_size = fx_float_to_int(math.sqrt(T))
|
||||
height = width = grid_size * block_size
|
||||
x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
|
||||
x = x.transpose(2, 3).reshape(B, height, width, C)
|
||||
|
|
|
@ -27,6 +27,7 @@ import torch.nn as nn
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from timm.models.fx_features import register_leaf_module
|
||||
from .registry import register_model
|
||||
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
|
||||
get_act_layer, get_act_fn, get_attn, make_divisible
|
||||
|
@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
|
|||
return self.conv(self.pool(x))
|
||||
|
||||
|
||||
@register_leaf_module # FX feature extraction was giving different valued features. Perhaps to do with control flow?
|
||||
class NormFreeBlock(nn.Module):
|
||||
"""Normalization-Free pre-activation block.
|
||||
"""
|
||||
|
|
|
@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe
|
|||
Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
from math import ceil
|
||||
|
@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module):
|
|||
if self.use_shortcut:
|
||||
if self.drop_path is not None:
|
||||
x = self.drop_path(x)
|
||||
x[:, 0:self.in_channels] += shortcut
|
||||
x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
@ -22,10 +22,12 @@ import torch.utils.checkpoint as checkpoint
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
from .fx_helpers import fx_float_to_int
|
||||
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
|
||||
from .registry import register_model
|
||||
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -111,7 +113,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
|
|||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
@ -175,7 +177,7 @@ class WindowAttention(nn.Module):
|
|||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
attn = torch.matmul(q, k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
|
@ -192,7 +194,7 @@ class WindowAttention(nn.Module):
|
|||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
@ -270,7 +272,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
def forward(self, x):
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
torch._assert(L == H * W, "input feature has wrong size")
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
|
@ -329,8 +331,8 @@ class PatchMerging(nn.Module):
|
|||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
||||
torch._assert(L == H * W, "input feature has wrong size")
|
||||
torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
|
|
|
@ -9,10 +9,10 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.helpers import build_model_with_cfg
|
||||
from timm.models.fx_helpers import fx_and
|
||||
from timm.models.layers import Mlp, DropPath, trunc_normal_
|
||||
from timm.models.layers.helpers import to_2tuple
|
||||
from timm.models.registry import register_model
|
||||
|
@ -64,11 +64,11 @@ class Attention(nn.Module):
|
|||
q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
@ -109,7 +109,9 @@ class Block(nn.Module):
|
|||
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
|
||||
# outer
|
||||
B, N, C = patch_embed.size()
|
||||
patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
|
||||
patch_embed = torch.cat(
|
||||
[patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
|
||||
dim=1)
|
||||
patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
|
||||
patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
|
||||
return pixel_embed, patch_embed
|
||||
|
@ -136,8 +138,8 @@ class PixelEmbed(nn.Module):
|
|||
|
||||
def forward(self, x, pixel_pos):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]),
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
||||
x = self.proj(x)
|
||||
x = self.unfold(x)
|
||||
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
|
||||
|
|
|
@ -22,6 +22,7 @@ from functools import partial
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
|
||||
from .fx_features import register_leaf_module
|
||||
from .registry import register_model
|
||||
from .vision_transformer import Attention
|
||||
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||
|
@ -62,6 +63,7 @@ default_cfgs = {
|
|||
Size_ = Tuple[int, int]
|
||||
|
||||
|
||||
@register_leaf_module # FX can't symbolically trace control flow in forward method
|
||||
class LocallyGroupedAttn(nn.Module):
|
||||
""" LSA: self attention within a group
|
||||
"""
|
||||
|
@ -98,10 +100,10 @@ class LocallyGroupedAttn(nn.Module):
|
|||
qkv = self.qkv(x).reshape(
|
||||
B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
|
||||
attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
|
||||
x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
@ -183,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module):
|
|||
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from typing import Union, List, Dict, Any, cast
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .fx_features import register_leaf_module
|
||||
from .layers import ClassifierHead, ConvBnAct
|
||||
from .registry import register_model
|
||||
|
||||
|
@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
|
|||
}
|
||||
|
||||
|
||||
@register_leaf_module # FX can't symbolically trace control flow in forward method
|
||||
class ConvMlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
|
||||
|
|
|
@ -100,10 +100,10 @@ class Attention(nn.Module):
|
|||
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
|
||||
q, k, v = x[0], x[1], x[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
x = torch.matmul(attn, v)
|
||||
|
||||
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
|
||||
x = self.proj(x)
|
||||
|
|
|
@ -192,11 +192,11 @@ class Attention(nn.Module):
|
|||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
|
|
@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp
|
|||
from .registry import register_model
|
||||
from .layers import DropPath, trunc_normal_, to_2tuple
|
||||
from .cait import ClassAttn
|
||||
from .fx_features import register_leaf_module
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
|
@ -97,6 +98,7 @@ default_cfgs = {
|
|||
}
|
||||
|
||||
|
||||
@register_leaf_module # FX can't symbolically trace torch.arange in forward method
|
||||
class PositionalEncodingFourier(nn.Module):
|
||||
"""
|
||||
Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
|
||||
|
@ -272,12 +274,12 @@ class XCA(nn.Module):
|
|||
# Paper section 3.2 l2-Normalization and temperature scaling
|
||||
q = torch.nn.functional.normalize(q, dim=-1)
|
||||
k = torch.nn.functional.normalize(k, dim=-1)
|
||||
attn = (q @ k.transpose(-2, -1)) * self.temperature
|
||||
attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
# (B, H, C', N), permute -> (B, N, H, C')
|
||||
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
|
||||
x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
|
Loading…
Reference in New Issue