Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 74ad32a67e (parent 69b1fbcdc1): Update TNT model weights on the hub; add back the legacy model for backward compatibility.

The diff below changes the TNT model implementation: the pretrained configs now point at hub-hosted weights, a tnt_s_legacy_patch16_224 variant is registered, and checkpoint_filter_fn passes already-converted state dicts through unchanged.
@@ -22,13 +22,13 @@ from ._features import feature_take_indices
 from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this


 class Attention(nn.Module):
     """ Multi-Head Attention
     """

     def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.hidden_dim = hidden_dim
@@ -46,7 +46,7 @@ class Attention(nn.Module):
     def forward(self, x):
         B, N, C = x.shape
         qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

         attn = (q @ k.transpose(-2, -1)) * self.scale
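Note: this Attention keeps separate qk and v projections rather than a fused qkv, and v retains the full embedding width while q/k use hidden_dim. A minimal standalone sketch of the same shape flow, with hypothetical names and illustrative sizes (not the timm module itself):

import torch
import torch.nn as nn

# Sketch of the qk/v split used above; sizes are illustrative.
dim, hidden_dim, num_heads = 384, 384, 6
head_dim = hidden_dim // num_heads
scale = head_dim ** -0.5
qk_proj = nn.Linear(dim, hidden_dim * 2, bias=False)  # q and k come from one projection
v_proj = nn.Linear(dim, dim, bias=False)              # v keeps the full embedding dim

x = torch.randn(2, 197, dim)                          # (B, N, C)
B, N, C = x.shape
qk = qk_proj(x).reshape(B, N, 2, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k = qk.unbind(0)                                   # each (B, num_heads, N, head_dim)
v = v_proj(x).reshape(B, N, num_heads, -1).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * scale              # (B, num_heads, N, N)
out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, -1)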
@@ -62,6 +62,7 @@ class Attention(nn.Module):
 class Block(nn.Module):
     """ TNT Block
     """
+
     def __init__(
             self,
             dim,
@@ -136,13 +137,13 @@ class Block(nn.Module):
         B, N, C = patch_embed.size()
         if self.legacy:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
             ], dim=1)
         else:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
             ], dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
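Note: the two torch.cat edits above are purely cosmetic (the backslash continuation becomes one element per line); the computation is unchanged. Functionally, the inner (pixel) tokens of each patch are flattened, projected to the outer dim, and added to every patch token except the class token at index 0. A shape-only sketch of the non-legacy path, using stand-in norm/proj layers (hypothetical, not the timm Block):

import torch
import torch.nn as nn

B, N, C = 2, 197, 384        # outer tokens incl. class token
num_pixel, in_dim = 16, 24   # inner tokens per patch and their width

norm1_proj = nn.LayerNorm(num_pixel * in_dim)
proj = nn.Linear(num_pixel * in_dim, C)
norm2_proj = nn.LayerNorm(C)

pixel_embed = torch.randn(B * (N - 1), num_pixel, in_dim)
patch_embed = torch.randn(B, N, C)

# Flatten inner tokens per patch, project to outer dim, add to all non-class tokens.
patch_embed = torch.cat([
    patch_embed[:, 0:1],
    patch_embed[:, 1:] + norm2_proj(proj(norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
], dim=1)
print(patch_embed.shape)     # torch.Size([2, 197, 384])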
@@ -152,7 +153,16 @@ class Block(nn.Module):
 class PixelEmbed(nn.Module):
     """ Image to Pixel Embedding
     """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
+
+    def __init__(
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            in_dim=48,
+            stride=4,
+            legacy=False,
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -184,14 +194,17 @@ class PixelEmbed(nn.Module):

     def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
-        _assert(H == self.img_size[0],
-                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
-        _assert(W == self.img_size[1],
-                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(
+            H == self.img_size[0],
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(
+            W == self.img_size[1],
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         if self.legacy:
             x = self.proj(x)
             x = self.unfold(x)
-            x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+            x = x.transpose(1, 2).reshape(
+                B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
         else:
             x = self.unfold(x)
             x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
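Note: the reshape being re-wrapped here follows nn.Unfold's output layout: unfold returns (B, C*k*k, L) for L patch locations, and transpose + reshape turns that into one (C, k, k) tile per patch, which is what the non-legacy branch feeds to the pixel projection. A small self-contained illustration with made-up sizes (not the timm module):

import torch
import torch.nn as nn

B, C, H, W = 2, 3, 224, 224
patch_size = 16
unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

x = torch.randn(B, C, H, W)
cols = unfold(x)                      # (B, C * 16 * 16, num_patches) = (2, 768, 196)
num_patches = cols.shape[-1]
tiles = cols.transpose(1, 2).reshape(B * num_patches, C, patch_size, patch_size)
print(tiles.shape)                    # torch.Size([392, 3, 16, 16])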
@@ -204,6 +217,7 @@ class PixelEmbed(nn.Module):
 class TNT(nn.Module):
     """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
     """
+
     def __init__(
             self,
             img_size=224,
@@ -458,42 +472,47 @@ def _cfg(url='', **kwargs):


 default_cfgs = generate_default_cfgs({
+    'tnt_s_legacy_patch16_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+    ),
     'tnt_s_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+        hf_hub_id='timm/',
+        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
     ),
     'tnt_b_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+        hf_hub_id='timm/',
+        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
     ),
 })


 def checkpoint_filter_fn(state_dict, model):
     state_dict.pop('outer_tokens', None)
-    out_dict = {}
-    for k, v in state_dict.items():
-        k = k.replace('outer_pos', 'patch_pos')
-        k = k.replace('inner_pos', 'pixel_pos')
-        k = k.replace('patch_embed', 'pixel_embed')
-        k = k.replace('proj_norm1', 'norm1_proj')
-        k = k.replace('proj_norm2', 'norm2_proj')
-        k = k.replace('inner_norm1', 'norm_in')
-        k = k.replace('inner_attn', 'attn_in')
-        k = k.replace('inner_norm2', 'norm_mlp_in')
-        k = k.replace('inner_mlp', 'mlp_in')
-        k = k.replace('outer_norm1', 'norm_out')
-        k = k.replace('outer_attn', 'attn_out')
-        k = k.replace('outer_norm2', 'norm_mlp')
-        k = k.replace('outer_mlp', 'mlp')
-        if k == 'pixel_pos' and model.pixel_embed.legacy == False:
-            B, N, C = v.shape
-            H = W = int(N ** 0.5)
-            assert H * W == N
-            v = v.permute(0, 2, 1).reshape(B, C, H, W)
-        out_dict[k] = v
+    if 'patch_pos' in state_dict:
+        out_dict = state_dict
+    else:
+        out_dict = {}
+        for k, v in state_dict.items():
+            k = k.replace('outer_pos', 'patch_pos')
+            k = k.replace('inner_pos', 'pixel_pos')
+            k = k.replace('patch_embed', 'pixel_embed')
+            k = k.replace('proj_norm1', 'norm1_proj')
+            k = k.replace('proj_norm2', 'norm2_proj')
+            k = k.replace('inner_norm1', 'norm_in')
+            k = k.replace('inner_attn', 'attn_in')
+            k = k.replace('inner_norm2', 'norm_mlp_in')
+            k = k.replace('inner_mlp', 'mlp_in')
+            k = k.replace('outer_norm1', 'norm_out')
+            k = k.replace('outer_attn', 'attn_out')
+            k = k.replace('outer_norm2', 'norm_mlp')
+            k = k.replace('outer_mlp', 'mlp')
+            if k == 'pixel_pos' and model.pixel_embed.legacy == False:
+                B, N, C = v.shape
+                H = W = int(N ** 0.5)
+                assert H * W == N
+                v = v.permute(0, 2, 1).reshape(B, C, H, W)
+            out_dict[k] = v

     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if out_dict['patch_pos'].shape != model.patch_pos.shape:
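Note: with this change, checkpoints already in timm's naming scheme (they contain 'patch_pos') are used as-is, while original TNT releases (keys like 'outer_pos'/'inner_pos') still go through the remapping branch, which also converts 'pixel_pos' from the (B, N, C) layout the original weights use to the (B, C, H, W) layout the non-legacy PixelEmbed expects. A toy illustration of that conversion step, with a made-up tensor rather than real weights:

import torch

# Stand-in for the 'pixel_pos' handling in the remapping branch above.
v = torch.randn(1, 16, 24)      # hypothetical (B, N, C) with N = 4 * 4 inner positions
B, N, C = v.shape
H = W = int(N ** 0.5)
assert H * W == N
v = v.permute(0, 2, 1).reshape(B, C, H, W)
print(v.shape)                  # torch.Size([1, 24, 4, 4])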
@@ -515,6 +534,15 @@ def _create_tnt(variant, pretrained=False, **kwargs):
     return model


+@register_model
+def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
+    model_cfg = dict(
+        patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
+        qkv_bias=False, legacy=True)
+    model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(
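Note: with the legacy variant registered, both entry points go through the usual timm factory. A usage sketch; it assumes a timm build containing this commit and, for pretrained=True, that the hub weights referenced by hf_hub_id='timm/' are published:

import timm
import torch

# legacy=True graph, added back in this commit for backward compatibility
legacy = timm.create_model('tnt_s_legacy_patch16_224', pretrained=True)

# current (non-legacy) graph with the updated hub weights
model = timm.create_model('tnt_s_patch16_224', pretrained=True)

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(legacy(x).shape, model(x).shape)  # both torch.Size([1, 1000])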