wip -rebase

Alexander Soare 2021-11-07 15:04:19 +00:00
parent ab3ac3f25b
commit bc3d4eb403
23 changed files with 269 additions and 63 deletions
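The hunks below repeat one small set of mechanical changes to make the models traceable with torch.fx: tensor `@` matmuls become `torch.matmul` calls, bare `assert` statements become `torch._assert` calls, `int()` and `and` on tensor-derived values go through tiny `fx_float_to_int` / `fx_and` helpers, and modules whose `forward` contains untraceable control flow are decorated with `@register_leaf_module`. The helper modules themselves (`fx_helpers`, `fx_features`) are not part of this excerpt; the sketch below is only an illustration of what they might look like, built on `torch.fx.wrap`, and is not the actual timm code.

# Hypothetical sketch of the FX helpers assumed by the new imports in this commit (not the real timm source).
import torch
import torch.fx


def fx_float_to_int(x: float) -> int:
    # int() on a value derived from a traced Proxy breaks symbolic tracing; a wrapped helper is
    # kept as a single call_function node in the traced graph instead.
    return int(x)


def fx_and(a: bool, b: bool) -> bool:
    # `and` is Python control flow the tracer cannot follow; a wrapped helper avoids that.
    return a and b


# Wrapping keeps the helpers un-traced (leaf functions) during torch.fx.symbolic_trace.
torch.fx.wrap('fx_float_to_int')
torch.fx.wrap('fx_and')

# Modules registered here would be handed to a custom Tracer as leaf modules, so their forward()
# (and any control flow inside it) is never traced through.
_leaf_modules = set()


def register_leaf_module(module_cls):
    _leaf_modules.add(module_cls)
    return module_cls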

View File

@@ -95,11 +95,11 @@ class ClassAttn(nn.Module):
        q = q * self.scale
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-        attn = (q @ k.transpose(-2, -1))
+        attn = torch.matmul(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
+        x_cls = torch.matmul(attn, v).transpose(1, 2).reshape(B, 1, C)
        x_cls = self.proj(x_cls)
        x_cls = self.proj_drop(x_cls)
@@ -158,7 +158,7 @@ class TalkingHeadAttn(nn.Module):
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
-        attn = (q @ k.transpose(-2, -1))
+        attn = torch.matmul(q, k.transpose(-2, -1))
        attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@@ -167,7 +167,7 @@ class TalkingHeadAttn(nn.Module):
        attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

View File

@@ -105,7 +105,7 @@ class ConvRelPosEnc(nn.Module):
    def forward(self, q, v, size: Tuple[int, int]):
        B, h, N, Ch = q.shape
        H, W = size
-        assert N == 1 + H * W
+        torch._assert(N == 1 + H * W, '')
        # Convolutional relative position encoding.
        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
@@ -149,8 +149,8 @@ class FactorAtt_ConvRelPosEnc(nn.Module):
        # Factorized attention.
        k_softmax = k.softmax(dim=2)
-        factor_att = k_softmax.transpose(-1, -2) @ v
+        factor_att = torch.matmul(k_softmax.transpose(-1, -2), v)
-        factor_att = q @ factor_att
+        factor_att = torch.matmul(q, factor_att)
        # Convolutional relative position encoding.
        crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]
@@ -177,7 +177,7 @@ class ConvPosEnc(nn.Module):
    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
-        assert N == 1 + H * W
+        torch._assert(N == 1 + H * W, '')
        # Extract CLS token and image tokens.
        cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]
@@ -275,7 +275,7 @@ class ParallelBlock(nn.Module):
        """ Feature map interpolation. """
        B, N, C = x.shape
        H, W = size
-        assert N == 1 + H * W
+        torch._assert(N == 1 + H * W, '')
        cls_token = x[:, :1, :]
        img_tokens = x[:, 1:, :]
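The `assert` to `torch._assert` swaps here (and in the files below) all follow the same reasoning: a bare `assert` is data-dependent Python control flow, which `torch.fx` cannot trace symbolically, while `torch._assert` is an ordinary function call that is recorded as a graph node and still raises at run time. A minimal stand-alone illustration (not taken from this commit):

# Hypothetical example of why the swap keeps shape checks traceable.
import torch
import torch.fx


def check_even_channels(x: torch.Tensor):
    # torch._assert shows up as a node in the traced graph; a plain `assert` on this
    # condition would fail at trace time because the Proxy cannot be coerced to bool.
    torch._assert(x.shape[-1] % 2 == 0, 'expected an even number of channels')
    return x * 2


traced = torch.fx.symbolic_trace(check_even_channels)
print(traced.code)  # the _assert call is preserved in the generated forward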

View File

@@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed
+from .fx_features import register_leaf_module
import torch
import torch.nn as nn
@@ -56,6 +57,7 @@ default_cfgs = {
}
+@register_leaf_module  # FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                 locality_strength=1.):
@@ -82,7 +84,7 @@ class GPSA(nn.Module):
            self.rel_indices = self.get_rel_indices(N)
        attn = self.get_attention(x)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
@@ -93,7 +95,7 @@ class GPSA(nn.Module):
        q, k = qk[0], qk[1]
        pos_score = self.rel_indices.expand(B, -1, -1, -1)
        pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
-        patch_score = (q @ k.transpose(-2, -1)) * self.scale
+        patch_score = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        patch_score = patch_score.softmax(dim=-1)
        pos_score = pos_score.softmax(dim=-1)
@@ -178,11 +180,11 @@ class MHSA(nn.Module):
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

View File

@@ -22,6 +22,7 @@ import torch.nn.functional as F
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
+from timm.models.fx_helpers import fx_and
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -36,7 +37,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
        permute_mask: permute output dim according to this
    """
    B, H, W, dim = q.shape
-    x = (q @ rel_k.transpose(-1, -2))
+    x = torch.matmul(q, rel_k.transpose(-1, -2))
    x = x.reshape(-1, W, 2 * W -1)
    # pad to shift from relative to absolute indexing
@@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module):
    def forward(self, x):
        B, C, H, W = x.shape
-        assert H == self.pos_embed.height
+        torch._assert(H == self.pos_embed.height, '')
-        assert W == self.pos_embed.width
+        torch._assert(W == self.pos_embed.width, '')
        x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
@@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module):
        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
        out = self.pool(out)
        return out

View File

@@ -72,9 +72,9 @@ class EvoNormSample2d(nn.Module):
            nn.init.ones_(self.v)

    def forward(self, x):
-        assert x.dim() == 4, 'expected 4D input'
+        torch._assert(x.dim() == 4, 'expected 4D input')
        B, C, H, W = x.shape
-        assert C % self.groups == 0
+        torch._assert(C % self.groups == 0, '')
        if self.apply_act:
            n = x * (x * self.v).sigmoid()
            x = x.reshape(B, self.groups, -1)

View File

@@ -7,6 +7,7 @@ Official code consulted as reference: https://github.com/xvjiarui/GCNet
Hacked together by / Copyright 2021 Ross Wightman
"""
+import torch
from torch import nn as nn
import torch.nn.functional as F
@@ -52,7 +53,7 @@ class GlobalContext(nn.Module):
        if self.conv_attn is not None:
            attn = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H * W)
            attn = F.softmax(attn, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)
-            context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
+            context = torch.matmul(x.reshape(B, C, H * W).unsqueeze(1), attn)
            context = context.view(B, C, 1, 1)
        else:
            context = x.mean(dim=(2, 3), keepdim=True)

View File

@@ -24,6 +24,7 @@ import torch.nn.functional as F
from .helpers import make_divisible
from .weight_init import trunc_normal_
+from timm.models.fx_helpers import fx_and
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -41,7 +42,7 @@ def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    rel_size = rel_k.shape[0]
    win_size = (rel_size + 1) // 2
-    x = (q @ rel_k.transpose(-1, -2))
+    x = torch.matmul(q, rel_k.transpose(-1, -2))
    x = x.reshape(-1, W, rel_size)
    # pad to shift from relative to absolute indexing
@@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
    def forward(self, x):
        B, C, H, W = x.shape
-        assert H % self.block_size == 0
+        torch._assert(H % self.block_size == 0, '')
-        assert W % self.block_size == 0
+        torch._assert(W % self.block_size == 0, '')
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks

View File

@@ -116,8 +116,8 @@ class LambdaLayer(nn.Module):
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M
-        content_lam = k @ v  # B, K, V
+        content_lam = torch.matmul(k, v)  # B, K, V
-        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V
+        content_out = torch.matmul(q, content_lam.unsqueeze(1))  # B, num_heads, M, V
        if self.pos_emb is None:
            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K

View File

@@ -10,6 +10,7 @@ from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
+from timm.models.fx_helpers import fx_and
class NonLocalAttn(nn.Module):
@@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
    def resize_mat(self, x, t: int):
        B, C, block_size, block_size1 = x.shape
-        assert block_size == block_size1
+        torch._assert(block_size == block_size1, '')
        if t <= 1:
            return x
        x = x.view(B * C, -1, 1, 1)
@@ -95,7 +96,7 @@ class BilinearAttnTransform(nn.Module):
        return x

    def forward(self, x):
-        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
+        torch._assert(fx_and(x.shape[-1] % self.block_size == 0, x.shape[-2] % self.block_size == 0), '')
        B, C, H, W = x.shape
        out = self.conv1(x)
        rp = F.adaptive_max_pool2d(out, (self.block_size, 1))

View File

@@ -9,7 +9,11 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn

from .helpers import to_2tuple
+<<<<<<< HEAD
from .trace_utils import _assert
+=======
+from timm.models.fx_helpers import fx_and
+>>>>>>> Make all models FX traceable


class PatchEmbed(nn.Module):

View File

@@ -34,7 +34,7 @@ class SelectiveKernelAttn(nn.Module):
        self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)

    def forward(self, x):
-        assert x.shape[1] == self.num_paths
+        torch._assert(x.shape[1] == self.num_paths, '')
        x = x.sum(1).mean((2, 3), keepdim=True)
        x = self.fc_reduce(x)
        x = self.bn(x)

View File

@@ -0,0 +1,183 @@
""" Shifted Window Attn

This is a WIP experiment to apply windowed attention from the Swin Transformer
to a stand-alone module for use as an attn block in conv nets.

Based on original swin window code at https://github.com/microsoft/Swin-Transformer

Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf
"""
from typing import Optional

import torch
import torch.nn as nn

from .drop import DropPath
from .helpers import to_2tuple
from .weight_init import trunc_normal_
from timm.models.fx_helpers import fx_float_to_int


def window_partition(x, win_size: int):
    """
    Args:
        x: (B, H, W, C)
        win_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
    return windows


def window_reverse(windows, win_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        win_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = fx_float_to_int(windows.shape[0] / (H * W / win_size / win_size))
    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        win_size (int): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
    """

    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8,
            qkv_bias=True, attn_drop=0.):
        super().__init__()
        self.dim_out = dim_out or dim
        self.feat_size = to_2tuple(feat_size)
        self.win_size = win_size
        self.shift_size = shift_size or win_size // 2
        if min(self.feat_size) <= win_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.win_size = min(self.feat_size)
        assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size"
        self.num_heads = num_heads
        head_dim = self.dim_out // num_heads
        self.scale = head_dim ** -0.5

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.feat_size
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.win_size),
                slice(-self.win_size, -self.shift_size),
                slice(-self.shift_size, None))
            w_slices = (
                slice(0, -self.win_size),
                slice(-self.win_size, -self.shift_size),
                slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.win_size)  # num_win, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.win_size * self.win_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            # 2 * Wh - 1 * 2 * Ww - 1, nH
            torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
        trunc_normal_(self.relative_position_bias_table, std=.02)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.win_size)
        coords_w = torch.arange(self.win_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.win_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.win_size - 1
        relative_coords[:, :, 0] *= 2 * self.win_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

    def reset_parameters(self):
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
        trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        win_size_sq = self.win_size * self.win_size
        x_windows = window_partition(shifted_x, self.win_size)  # num_win * B, window_size, window_size, C
        x_windows = x_windows.view(-1, win_size_sq, C)  # num_win * B, window_size*window_size, C
        BW, N, _ = x_windows.shape

        qkv = self.qkv(x_windows)
        qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = torch.matmul(q, k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh * Ww, Wh * Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if self.attn_mask is not None:
            num_win = self.attn_mask.shape[0]
            attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = torch.matmul(attn, v).transpose(1, 2).reshape(BW, N, self.dim_out)

        # merge windows
        x = x.view(-1, self.win_size, self.win_size, self.dim_out)
        shifted_x = window_reverse(x, self.win_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2)
        x = self.pool(x)
        return x
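A quick shape check of the new stand-alone block; the sizes below are made up for illustration and are not taken from this commit:

# Hypothetical smoke test for the shifted-window attention block above.
import torch

attn = WindowAttention(dim=256, dim_out=256, feat_size=(16, 16), win_size=8, num_heads=8)
attn.reset_parameters()

x = torch.randn(2, 256, 16, 16)   # NCHW feature map, as produced inside a conv net
out = attn(x)
print(out.shape)                  # torch.Size([2, 256, 16, 16]); halved spatially if stride=2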

View File

@@ -293,10 +293,10 @@ class Attention(nn.Module):
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
-        attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
        attn = attn.softmax(dim=-1)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
@@ -387,10 +387,10 @@ class AttentionSubsample(nn.Module):
        v = v.permute(0, 2, 1, 3)  # BHNC
        q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
-        attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + self.get_attention_biases(x.device)
        attn = attn.softmax(dim=-1)
-        x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, -1, self.dh)
        x = self.proj(x)
        return x

View File

@@ -26,10 +26,12 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
+from .fx_helpers import fx_float_to_int
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model

_logger = logging.getLogger(__name__)
@@ -83,12 +85,12 @@ class Attention(nn.Module):
        qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # (B, H, T, N, C'), permute -> (B, T, N, C', H)
-        x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
+        x = torch.matmul(attn, v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x  # (B, T, N, C)
@@ -128,8 +130,8 @@ class ConvPool(nn.Module):
        """
        x is expected to have shape (B, C, H, W)
        """
-        assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims'
+        torch._assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
-        assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims'
+        torch._assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
        x = self.conv(x)
        # Layer norm done over channel dim only
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@@ -144,8 +146,8 @@ def blockify(x, block_size: int):
        block_size (int): edge length of a single square block in units of H, W
    """
    B, H, W, C = x.shape
-    assert H % block_size == 0, '`block_size` must divide input height evenly'
+    torch._assert(H % block_size == 0, '`block_size` must divide input height evenly')
-    assert W % block_size == 0, '`block_size` must divide input width evenly'
+    torch._assert(W % block_size == 0, '`block_size` must divide input width evenly')
    grid_height = H // block_size
    grid_width = W // block_size
    x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
@@ -160,7 +162,7 @@ def deblockify(x, block_size: int):
        block_size (int): edge length of a single square block in units of desired H, W
    """
    B, T, _, C = x.shape
-    grid_size = int(math.sqrt(T))
+    grid_size = fx_float_to_int(math.sqrt(T))
    height = width = grid_size * block_size
    x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
    x = x.transpose(2, 3).reshape(B, height, width, C)

View File

@@ -27,6 +27,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
+from timm.models.fx_features import register_leaf_module
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
    get_act_layer, get_act_fn, get_attn, make_divisible
@@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
        return self.conv(self.pool(x))

+@register_leaf_module  # FX feature extraction was giving different valued features. Perhaps to do with control flow?
class NormFreeBlock(nn.Module):
    """Normalization-Free pre-activation block.
    """

View File

@@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked together
Copyright 2020 Ross Wightman
"""
+import torch
import torch.nn as nn
from functools import partial
from math import ceil
@@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module):
        if self.use_shortcut:
            if self.drop_path is not None:
                x = self.drop_path(x)
-            x[:, 0:self.in_channels] += shortcut
+            x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
        return x
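The rewrite above trades an in-place add into a slice of `x` for a purely functional `torch.cat`, keeping the residual connection out of tensor mutation that symbolic tracing and graph rewriting handle poorly. A small equivalence check with made-up shapes (not part of the commit):

# Hypothetical check that the functional form matches the original in-place slice update.
import torch

in_channels = 64
x = torch.randn(2, 96, 7, 7)
shortcut = torch.randn(2, in_channels, 7, 7)

y_inplace = x.clone()
y_inplace[:, 0:in_channels] += shortcut                     # original formulation

y_cat = torch.cat(                                          # formulation used in this commit
    [x[:, 0:in_channels] + shortcut, x[:, in_channels:]], dim=1)

print(torch.allclose(y_inplace, y_cat))                     # True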

View File

@@ -22,10 +22,12 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .fx_helpers import fx_float_to_int
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights

_logger = logging.getLogger(__name__)
@@ -111,7 +113,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
    Returns:
        x: (B, H, W, C)
    """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    B = fx_float_to_int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
@@ -175,7 +177,7 @@ class WindowAttention(nn.Module):
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
-        attn = (q @ k.transpose(-2, -1))
+        attn = torch.matmul(q, k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
@@ -192,7 +194,7 @@ class WindowAttention(nn.Module):
        attn = self.attn_drop(attn)

-        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
@@ -270,7 +272,7 @@ class SwinTransformerBlock(nn.Module):
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        torch._assert(L == H * W, "input feature has wrong size")

        shortcut = x
        x = self.norm1(x)
@@ -329,8 +331,8 @@ class PatchMerging(nn.Module):
        """
        H, W = self.input_resolution
        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        torch._assert(L == H * W, "input feature has wrong size")
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+        torch._assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")

        x = x.view(B, H, W, C)

View File

@@ -9,10 +9,10 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
import math
import torch
import torch.nn as nn
-from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
+from timm.models.fx_helpers import fx_and
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model
@@ -64,11 +64,11 @@ class Attention(nn.Module):
        q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
@@ -109,7 +109,9 @@ class Block(nn.Module):
        pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
        # outer
        B, N, C = patch_embed.size()
-        patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
+        patch_embed = torch.cat(
+            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
+            dim=1)
        patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
        patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
        return pixel_embed, patch_embed
@@ -136,8 +138,8 @@ class PixelEmbed(nn.Module):
    def forward(self, x, pixel_pos):
        B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        torch._assert(fx_and(H == self.img_size[0], W == self.img_size[1]),
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        x = self.proj(x)
        x = self.unfold(x)
        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])

View File

@@ -22,6 +22,7 @@ from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from .fx_features import register_leaf_module
from .registry import register_model
from .vision_transformer import Attention
from .helpers import build_model_with_cfg, overlay_external_default_cfg
@@ -62,6 +63,7 @@ default_cfgs = {
Size_ = Tuple[int, int]

+@register_leaf_module  # FX can't symbolically trace control flow in forward method
class LocallyGroupedAttn(nn.Module):
    """ LSA: self attention within a group
    """
@@ -98,10 +100,10 @@ class LocallyGroupedAttn(nn.Module):
        qkv = self.qkv(x).reshape(
            B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
+        attn = torch.matmul(attn, v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
        x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()
@@ -183,11 +185,11 @@ class GlobalSubSampleAttn(nn.Module):
        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

View File

@@ -12,6 +12,7 @@ from typing import Union, List, Dict, Any, cast
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
+from .fx_features import register_leaf_module
from .layers import ClassifierHead, ConvBnAct
from .registry import register_model
@@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
}

+@register_leaf_module  # FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module):
    def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,

View File

@@ -100,10 +100,10 @@ class Attention(nn.Module):
        x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
        q, k, v = x[0], x[1], x[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        x = attn @ v
+        x = torch.matmul(attn, v)
        x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
        x = self.proj(x)

View File

@@ -192,11 +192,11 @@ class Attention(nn.Module):
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
-        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

View File

@@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp
from .registry import register_model
from .layers import DropPath, trunc_normal_, to_2tuple
from .cait import ClassAttn
+from .fx_features import register_leaf_module


def _cfg(url='', **kwargs):
@@ -97,6 +98,7 @@ default_cfgs = {
}

+@register_leaf_module  # FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
    """
    Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
@@ -272,12 +274,12 @@ class XCA(nn.Module):
        # Paper section 3.2 l2-Normalization and temperature scaling
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
-        attn = (q @ k.transpose(-2, -1)) * self.temperature
+        attn = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, H, C', N), permute -> (B, N, H, C')
-        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
+        x = torch.matmul(attn, v).permute(0, 3, 1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x