TinyVitBlock needs adding as leaf for FX now, tweak a few dim names
parent 9caf32b93f
commit 507cb08acf
@@ -21,6 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
     trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs

@@ -178,18 +179,15 @@ class Attention(torch.nn.Module):
         self.num_heads = num_heads
         self.scale = key_dim ** -0.5
         self.key_dim = key_dim
-        self.nh_kd = nh_kd = key_dim * num_heads
-        self.d = int(attn_ratio * key_dim)
-        self.dh = int(attn_ratio * key_dim) * num_heads
+        self.val_dim = int(attn_ratio * key_dim)
+        self.out_dim = self.val_dim * num_heads
         self.attn_ratio = attn_ratio
         self.resolution = resolution
         self.fused_attn = use_fused_attn()

-        h = self.dh + nh_kd * 2
-
         self.norm = nn.LayerNorm(dim)
-        self.qkv = nn.Linear(dim, h)
-        self.proj = nn.Linear(self.dh, dim)
+        self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim))
+        self.proj = nn.Linear(self.out_dim, dim)

         points = list(itertools.product(range(resolution[0]), range(resolution[1])))
         N = len(points)
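
The rename is bookkeeping only: per head, queries and keys keep width `key_dim` while values get `val_dim = int(attn_ratio * key_dim)`, so the fused qkv projection is `num_heads * (2 * key_dim + val_dim)` wide and the concatenated head output `out_dim = num_heads * val_dim` is projected back to `dim`. A minimal sketch of the same arithmetic under the new names, with illustrative values that are not taken from any real TinyViT config:

# Dimension bookkeeping under the new names; dim, key_dim, num_heads, attn_ratio
# below are made-up illustrative values.
dim, key_dim, num_heads, attn_ratio = 192, 16, 6, 4

val_dim = int(attn_ratio * key_dim)            # per-head value width, was `self.d`
out_dim = val_dim * num_heads                  # concatenated head output, was `self.dh`
qkv_dim = num_heads * (val_dim + 2 * key_dim)  # fused qkv width, was `h = self.dh + nh_kd * 2`

assert qkv_dim == num_heads * val_dim + 2 * num_heads * key_dim  # same total width as before the rename
print(val_dim, out_dim, qkv_dim)               # 64, 384, 576
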
@@ -227,7 +225,7 @@ class Attention(torch.nn.Module):
         x = self.norm(x)
         qkv = self.qkv(x)
         # (B, N, num_heads, d)
-        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
+        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
         # (B, num_heads, N, d)
         q = q.permute(0, 2, 1, 3)
         k = k.permute(0, 2, 1, 3)
@@ -241,7 +239,7 @@ class Attention(torch.nn.Module):
             attn = attn + attn_bias
             attn = attn.softmax(dim=-1)
             x = attn @ v
-        x = x.transpose(1, 2).reshape(B, N, self.dh)
+        x = x.transpose(1, 2).reshape(B, N, self.out_dim)
         x = self.proj(x)
         return x

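
The forward path uses the same names at the qkv split and at the output reshape. Below is a standalone shape walk-through of that path under illustrative sizes, without the relative-position bias and without the fused-attention branch; it mirrors the tensor shapes in the two hunks above rather than timm's exact code:

# Shape walk-through of the renamed attention path; B, N and the dims are illustrative.
import torch

B, N = 2, 49
num_heads, key_dim, val_dim = 6, 16, 64
out_dim = val_dim * num_heads

qkv = torch.randn(B, N, num_heads * (val_dim + 2 * key_dim))
q, k, v = qkv.view(B, N, num_heads, -1).split([key_dim, key_dim, val_dim], dim=3)
q, k, v = (t.permute(0, 2, 1, 3) for t in (q, k, v))    # (B, num_heads, N, per-head dim)

attn = (q * key_dim ** -0.5) @ k.transpose(-2, -1)      # (B, num_heads, N, N)
x = attn.softmax(dim=-1) @ v                            # (B, num_heads, N, val_dim)
x = x.transpose(1, 2).reshape(B, N, out_dim)            # (B, N, out_dim) == (2, 49, 384)
print(x.shape)
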
@@ -311,7 +309,6 @@ class TinyVitBlock(nn.Module):
             pad_b = (self.window_size - H % self.window_size) % self.window_size
             pad_r = (self.window_size - W % self.window_size) % self.window_size
             padding = pad_b > 0 or pad_r > 0
-
             if padding:
                 x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
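
The padding amounts round H and W up to the next multiple of `window_size`, and the `if padding:` branch then depends on values derived from the input's shape. Data-dependent control flow like this is what torch.fx symbolic tracing cannot follow, which is why the block is registered as a leaf in the hunk below. A quick numeric illustration with made-up sizes:

# Illustrative padding arithmetic; window_size, H, W are made-up values.
window_size, H, W = 7, 13, 27
pad_b = (window_size - H % window_size) % window_size   # 1 -> pads H from 13 up to 14
pad_r = (window_size - W % window_size) % window_size   # 1 -> pads W from 27 up to 28
padding = pad_b > 0 or pad_r > 0                        # True
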
@@ -344,6 +341,9 @@ class TinyVitBlock(nn.Module):
                f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"


+register_notrace_module(TinyVitBlock)
+
+
 class TinyVitStage(nn.Module):
     """ A basic TinyViT layer for one stage.

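
`register_notrace_module` adds `TinyVitBlock` to the set of modules that timm's FX feature-extraction helpers treat as leaves, so the tracer records a single `call_module` node instead of descending into the shape-dependent window logic above. A rough sketch of the same idea with a plain `torch.fx.Tracer` and a stand-in block follows; it is an illustration of the leaf-module mechanism, not timm's actual extraction path:

# Leaf-module idea with a plain torch.fx tracer; ToyBlock is a stand-in, not TinyVitBlock.
import torch
import torch.nn as nn
import torch.fx as fx

class ToyBlock(nn.Module):
    def forward(self, x):
        # data-dependent branch: symbolic tracing would fail inside this module
        if x.shape[-1] % 2:
            x = nn.functional.pad(x, (0, 1))
        return x

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = ToyBlock()

    def forward(self, x):
        return self.block(x)

class LeafTracer(fx.Tracer):
    def is_leaf_module(self, m, qualname):
        # keep ToyBlock opaque, as register_notrace_module does for TinyVitBlock
        return isinstance(m, ToyBlock) or super().is_leaf_module(m, qualname)

graph = LeafTracer().trace(Model())
print(graph)  # the graph contains one call_module node for `block`

Without the leaf override, symbolic tracing would typically hit `bool()` on a traced shape expression inside the padding branch and raise a trace error.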