Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Updated tnt model weights on hub, add back legacy model in case bwd compat
commit 74ad32a67e (parent 69b1fbcdc1)
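For context, a minimal usage sketch (not part of the diff; it assumes a timm build that includes this commit) showing how the restored legacy variant and the updated variant registered below can be instantiated. Both entrypoint names appear in the @register_model section at the end of the diff.

```python
import torch
import timm

# 'tnt_s_legacy_patch16_224' keeps the old layer layout for backwards compatibility,
# 'tnt_s_patch16_224' pairs with the updated weights pushed to the hub.
# pretrained=False here to avoid assuming weight availability.
legacy = timm.create_model('tnt_s_legacy_patch16_224', pretrained=False)
current = timm.create_model('tnt_s_patch16_224', pretrained=False)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(legacy(x).shape, current(x).shape)  # expected: torch.Size([1, 1000]) for both
```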
@@ -22,13 +22,13 @@ from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model

__all__ = ['TNT']  # model_registry will add each entrypoint fn to this


class Attention(nn.Module):
    """ Multi-Head Attention
    """

    def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.hidden_dim = hidden_dim
@@ -46,7 +46,7 @@ class Attention(nn.Module):
    def forward(self, x):
        B, N, C = x.shape
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -62,6 +62,7 @@ class Attention(nn.Module):
class Block(nn.Module):
    """ TNT Block
    """

    def __init__(
            self,
            dim,
@@ -89,7 +90,7 @@ class Block(nn.Module):
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

        self.norm_mlp_in = norm_layer(dim)
        self.mlp_in = Mlp(
            in_features=dim,
@@ -118,7 +119,7 @@ class Block(nn.Module):
            proj_drop=proj_drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm_mlp = norm_layer(dim_out)
        self.mlp = Mlp(
            in_features=dim_out,
@@ -136,13 +137,13 @@ class Block(nn.Module):
        B, N, C = patch_embed.size()
        if self.legacy:
            patch_embed = torch.cat([
                patch_embed[:, 0:1], patch_embed[:, 1:] + \
                    self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
                patch_embed[:, 0:1],
                patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
            ], dim=1)
        else:
            patch_embed = torch.cat([
                patch_embed[:, 0:1], patch_embed[:, 1:] + \
                    self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
                patch_embed[:, 0:1],
                patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
            ], dim=1)
        patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
        patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
@@ -152,7 +153,16 @@ class Block(nn.Module):
class PixelEmbed(nn.Module):
    """ Image to Pixel Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            in_dim=48,
            stride=4,
            legacy=False,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
@@ -184,14 +194,17 @@ class PixelEmbed(nn.Module):

    def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        _assert(H == self.img_size[0],
        _assert(
            H == self.img_size[0],
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
        _assert(
            W == self.img_size[1],
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        if self.legacy:
            x = self.proj(x)
            x = self.unfold(x)
            x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
            x = x.transpose(1, 2).reshape(
                B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
        else:
            x = self.unfold(x)
            x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
@@ -204,6 +217,7 @@ class PixelEmbed(nn.Module):
class TNT(nn.Module):
    """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
    """

    def __init__(
            self,
            img_size=224,
@@ -248,7 +262,7 @@ class TNT(nn.Module):
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size[0] * new_patch_size[1]

        self.norm1_proj = norm_layer(num_pixel * inner_dim)
        self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
        self.norm2_proj = norm_layer(embed_dim)
@@ -278,7 +292,7 @@ class TNT(nn.Module):
        self.blocks = nn.ModuleList(blocks)
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]

        self.norm = norm_layer(embed_dim)
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -359,7 +373,7 @@ class TNT(nn.Module):
        B, _, height, width = x.shape
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
@@ -381,7 +395,7 @@ class TNT(nn.Module):
        # split prefix (e.g. class, distill) and spatial feature tokens
        prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
        intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]

        if reshape:
            # reshape to BCHW output format
            H, W = self.pixel_embed.dynamic_feat_size((height, width))
@@ -416,7 +430,7 @@ class TNT(nn.Module):
    def forward_features(self, x):
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
@@ -458,42 +472,47 @@ def _cfg(url='', **kwargs):


default_cfgs = generate_default_cfgs({
    'tnt_s_legacy_patch16_224.in1k': _cfg(
        hf_hub_id='timm/',
        #url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
    ),
    'tnt_s_patch16_224.in1k': _cfg(
        # hf_hub_id='timm/',
        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
        hf_hub_id='timm/',
        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
    ),
    'tnt_b_patch16_224.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
        hf_hub_id='timm/',
        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
    ),
})


def checkpoint_filter_fn(state_dict, model):
    state_dict.pop('outer_tokens', None)

    out_dict = {}
    for k, v in state_dict.items():
        k = k.replace('outer_pos', 'patch_pos')
        k = k.replace('inner_pos', 'pixel_pos')
        k = k.replace('patch_embed', 'pixel_embed')
        k = k.replace('proj_norm1', 'norm1_proj')
        k = k.replace('proj_norm2', 'norm2_proj')
        k = k.replace('inner_norm1', 'norm_in')
        k = k.replace('inner_attn', 'attn_in')
        k = k.replace('inner_norm2', 'norm_mlp_in')
        k = k.replace('inner_mlp', 'mlp_in')
        k = k.replace('outer_norm1', 'norm_out')
        k = k.replace('outer_attn', 'attn_out')
        k = k.replace('outer_norm2', 'norm_mlp')
        k = k.replace('outer_mlp', 'mlp')
        if k == 'pixel_pos' and model.pixel_embed.legacy == False:
            B, N, C = v.shape
            H = W = int(N ** 0.5)
            assert H * W == N
            v = v.permute(0, 2, 1).reshape(B, C, H, W)
        out_dict[k] = v
    if 'patch_pos' in state_dict:
        out_dict = state_dict
    else:
        out_dict = {}
        for k, v in state_dict.items():
            k = k.replace('outer_pos', 'patch_pos')
            k = k.replace('inner_pos', 'pixel_pos')
            k = k.replace('patch_embed', 'pixel_embed')
            k = k.replace('proj_norm1', 'norm1_proj')
            k = k.replace('proj_norm2', 'norm2_proj')
            k = k.replace('inner_norm1', 'norm_in')
            k = k.replace('inner_attn', 'attn_in')
            k = k.replace('inner_norm2', 'norm_mlp_in')
            k = k.replace('inner_mlp', 'mlp_in')
            k = k.replace('outer_norm1', 'norm_out')
            k = k.replace('outer_attn', 'attn_out')
            k = k.replace('outer_norm2', 'norm_mlp')
            k = k.replace('outer_mlp', 'mlp')
            if k == 'pixel_pos' and model.pixel_embed.legacy == False:
                B, N, C = v.shape
                H = W = int(N ** 0.5)
                assert H * W == N
                v = v.permute(0, 2, 1).reshape(B, C, H, W)
            out_dict[k] = v

    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    if out_dict['patch_pos'].shape != model.patch_pos.shape:
@@ -515,6 +534,15 @@ def _create_tnt(variant, pretrained=False, **kwargs):
    return model


@register_model
def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
    model_cfg = dict(
        patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
        qkv_bias=False, legacy=True)
    model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
    model_cfg = dict(
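To make the checkpoint_filter_fn remapping above easier to follow, here is a small standalone sketch (not from the commit; the key names follow the replacements shown in the diff, and the tensor shapes are illustrative only):

```python
import torch

# Toy old-format state dict: 'outer_*' / 'inner_*' names, pixel_pos stored as (1, N, C).
old_state = {
    'outer_pos': torch.zeros(1, 197, 384),
    'inner_pos': torch.zeros(1, 16, 24),                   # N = 16 -> H = W = 4
    'blocks.0.inner_attn.qk.weight': torch.zeros(48, 24),
}

# Subset of the renames applied by checkpoint_filter_fn in the diff above.
remap = [('outer_pos', 'patch_pos'), ('inner_pos', 'pixel_pos'), ('inner_attn', 'attn_in')]

new_state = {}
for k, v in old_state.items():
    for old, new in remap:
        k = k.replace(old, new)
    if k == 'pixel_pos':
        # Non-legacy models keep pixel_pos as (1, C, H, W), so reshape the flat (1, N, C) form.
        B, N, C = v.shape
        H = W = int(N ** 0.5)
        assert H * W == N
        v = v.permute(0, 2, 1).reshape(B, C, H, W)
    new_state[k] = v

print(sorted(new_state))              # ['blocks.0.attn_in.qk.weight', 'patch_pos', 'pixel_pos']
print(new_state['pixel_pos'].shape)   # torch.Size([1, 24, 4, 4])
```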