Updated tnt model weights on hub, add back legacy model in case bwd compat

This commit is contained in:
Ross Wightman 2025-05-14 08:40:53 -07:00
parent 69b1fbcdc1
commit 74ad32a67e

View File

@ -22,13 +22,13 @@ from ._features import feature_take_indices
from ._manipulate import checkpoint
from ._registry import generate_default_cfgs, register_model
__all__ = ['TNT'] # model_registry will add each entrypoint fn to this
class Attention(nn.Module):
""" Multi-Head Attention
"""
def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.hidden_dim = hidden_dim
@ -62,6 +62,7 @@ class Attention(nn.Module):
class Block(nn.Module):
""" TNT Block
"""
def __init__(
self,
dim,
@ -136,13 +137,13 @@ class Block(nn.Module):
B, N, C = patch_embed.size()
if self.legacy:
patch_embed = torch.cat([
patch_embed[:, 0:1], patch_embed[:, 1:] + \
self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
patch_embed[:, 0:1],
patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
], dim=1)
else:
patch_embed = torch.cat([
patch_embed[:, 0:1], patch_embed[:, 1:] + \
self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
patch_embed[:, 0:1],
patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
], dim=1)
patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
@ -152,7 +153,16 @@ class Block(nn.Module):
class PixelEmbed(nn.Module):
""" Image to Pixel Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
in_dim=48,
stride=4,
legacy=False,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@ -184,14 +194,17 @@ class PixelEmbed(nn.Module):
def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
_assert(H == self.img_size[0],
_assert(
H == self.img_size[0],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
_assert(W == self.img_size[1],
_assert(
W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
if self.legacy:
x = self.proj(x)
x = self.unfold(x)
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
x = x.transpose(1, 2).reshape(
B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
else:
x = self.unfold(x)
x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
@ -204,6 +217,7 @@ class PixelEmbed(nn.Module):
class TNT(nn.Module):
""" Transformer in Transformer - https://arxiv.org/abs/2103.00112
"""
def __init__(
self,
img_size=224,
@ -458,21 +472,26 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
'tnt_s_legacy_patch16_224.in1k': _cfg(
hf_hub_id='timm/',
#url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
),
'tnt_s_patch16_224.in1k': _cfg(
# hf_hub_id='timm/',
# url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
hf_hub_id='timm/',
#url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
),
'tnt_b_patch16_224.in1k': _cfg(
# hf_hub_id='timm/',
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
hf_hub_id='timm/',
#url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
),
})
def checkpoint_filter_fn(state_dict, model):
state_dict.pop('outer_tokens', None)
if 'patch_pos' in state_dict:
out_dict = state_dict
else:
out_dict = {}
for k, v in state_dict.items():
k = k.replace('outer_pos', 'patch_pos')
@ -515,6 +534,15 @@ def _create_tnt(variant, pretrained=False, **kwargs):
return model
@register_model
def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
model_cfg = dict(
patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
qkv_bias=False, legacy=True)
model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
return model
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
model_cfg = dict(