From b37f0f7a76ae6f6dc79d5eaf49cbdf61cc1fb2b7 Mon Sep 17 00:00:00 2001 From: Ryan <23580140+brianhou0208@users.noreply.github.com> Date: Fri, 2 May 2025 20:34:31 +0800 Subject: [PATCH] Update tnt.py --- timm/models/tnt.py | 131 ++++++++++++++++++++++++++++++++------------- 1 file changed, 93 insertions(+), 38 deletions(-) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index d97cfaae..a48fa5ac 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -5,6 +5,9 @@ A PyTorch implement of TNT as described in The official mindspore code is released and available at https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT + +The official pytorch code is released and available at +https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch """ import math from typing import Optional @@ -12,7 +15,7 @@ from typing import Optional import torch import torch.nn as nn -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed from ._builder import build_model_with_cfg from ._manipulate import checkpoint @@ -22,28 +25,6 @@ from ._registry import register_model __all__ = ['TNT'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'pixel_embed.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - 'tnt_s_patch16_224': _cfg( - url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), - 'tnt_b_patch16_224': _cfg( - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), -} - - class Attention(nn.Module): """ Multi-Head Attention """ @@ -94,6 +75,7 @@ class Block(nn.Module): drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, + legacy=False, ): super().__init__() # Inner transformer @@ -115,9 +97,14 @@ class Block(nn.Module): act_layer=act_layer, drop=proj_drop, ) - - self.norm1_proj = norm_layer(dim) - self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True) + self.legacy = legacy + if self.legacy: + self.norm1_proj = norm_layer(dim) + self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True) + else: + self.norm1_proj = norm_layer(dim * num_pixel) + self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False) + self.norm2_proj = norm_layer(dim_out) # Outer transformer self.norm_out = norm_layer(dim_out) @@ -146,9 +133,16 @@ class Block(nn.Module): pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) # outer B, N, C = patch_embed.size() - patch_embed = torch.cat( - [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))], - dim=1) + if self.legacy: + patch_embed = torch.cat([ + patch_embed[:, 0:1], patch_embed[:, 1:] + \ + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)), + ], dim=1) + else: + patch_embed = torch.cat([ + patch_embed[:, 0:1], patch_embed[:, 1:] + \ + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))), + ], dim=1) patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed))) patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) return pixel_embed, patch_embed @@ -157,7 +151,7 @@ class Block(nn.Module): class PixelEmbed(nn.Module): """ Image to Pixel Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): + def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -165,13 +159,18 @@ class PixelEmbed(nn.Module): self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) num_patches = (self.grid_size[0]) * (self.grid_size[1]) self.img_size = img_size + self.patch_size = patch_size + self.legacy = legacy self.num_patches = num_patches self.in_dim = in_dim new_patch_size = [math.ceil(ps / stride) for ps in patch_size] self.new_patch_size = new_patch_size self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) - self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) + if self.legacy: + self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) + else: + self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size) def forward(self, x, pixel_pos): B, C, H, W = x.shape @@ -179,9 +178,14 @@ class PixelEmbed(nn.Module): f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") _assert(W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") - x = self.proj(x) - x = self.unfold(x) - x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) + if self.legacy: + x = self.proj(x) + x = self.unfold(x) + x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) + else: + x = self.unfold(x) + x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1]) + x = self.proj(x) x = x + pixel_pos x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) return x @@ -211,6 +215,7 @@ class TNT(nn.Module): drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4, + legacy=False, ): super().__init__() assert global_pool in ('', 'token', 'avg') @@ -225,6 +230,7 @@ class TNT(nn.Module): in_chans=in_chans, in_dim=inner_dim, stride=first_stride, + legacy=legacy, ) num_patches = self.pixel_embed.num_patches self.num_patches = num_patches @@ -255,6 +261,7 @@ class TNT(nn.Module): attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + legacy=legacy, )) self.blocks = nn.ModuleList(blocks) self.norm = norm_layer(embed_dim) @@ -338,14 +345,38 @@ class TNT(nn.Module): def checkpoint_filter_fn(state_dict, model): + state_dict.pop('outer_tokens', None) + + out_dict = {} + for k, v in state_dict.items(): + k = k.replace('outer_pos', 'patch_pos') + k = k.replace('inner_pos', 'pixel_pos') + k = k.replace('patch_embed', 'pixel_embed') + k = k.replace('proj_norm1', 'norm1_proj') + k = k.replace('proj_norm2', 'norm2_proj') + k = k.replace('inner_norm1', 'norm_in') + k = k.replace('inner_attn', 'attn_in') + k = k.replace('inner_norm2', 'norm_mlp_in') + k = k.replace('inner_mlp', 'mlp_in') + k = k.replace('outer_norm1', 'norm_out') + k = k.replace('outer_attn', 'attn_out') + k = k.replace('outer_norm2', 'norm_mlp') + k = k.replace('outer_mlp', 'mlp') + if k == 'pixel_pos': + B, N, C = v.shape + H = W = int(N ** 0.5) + assert H * W == N + v = v.permute(0, 2, 1).reshape(B, C, H, W) + out_dict[k] = v + """ convert patch embedding weight from manual patchify + linear proj to conv""" - if state_dict['patch_pos'].shape != model.patch_pos.shape: - state_dict['patch_pos'] = resample_abs_pos_embed( - state_dict['patch_pos'], + if out_dict['patch_pos'].shape != model.patch_pos.shape: + out_dict['patch_pos'] = resample_abs_pos_embed( + out_dict['patch_pos'], new_size=model.pixel_embed.grid_size, num_prefix_tokens=1, ) - return state_dict + return out_dict def _create_tnt(variant, pretrained=False, **kwargs): @@ -359,6 +390,30 @@ def _create_tnt(variant, pretrained=False, **kwargs): return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'pixel_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'tnt_s_patch16_224': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', + url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar', + ), + 'tnt_b_patch16_224': _cfg( + # hf_hub_id='timm/', + url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar', + ), +} + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT: model_cfg = dict(