mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update tnt.py
This commit is contained in:
parent
c8c4f256b8
commit
b37f0f7a76
@ -5,6 +5,9 @@ A PyTorch implement of TNT as described in
|
|||||||
|
|
||||||
The official mindspore code is released and available at
|
The official mindspore code is released and available at
|
||||||
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
|
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
|
||||||
|
|
||||||
|
The official pytorch code is released and available at
|
||||||
|
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -12,7 +15,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
|
from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._manipulate import checkpoint
|
from ._manipulate import checkpoint
|
||||||
@ -22,28 +25,6 @@ from ._registry import register_model
|
|||||||
__all__ = ['TNT'] # model_registry will add each entrypoint fn to this
|
__all__ = ['TNT'] # model_registry will add each entrypoint fn to this
|
||||||
|
|
||||||
|
|
||||||
def _cfg(url='', **kwargs):
|
|
||||||
return {
|
|
||||||
'url': url,
|
|
||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
||||||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
||||||
'first_conv': 'pixel_embed.proj', 'classifier': 'head',
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
default_cfgs = {
|
|
||||||
'tnt_s_patch16_224': _cfg(
|
|
||||||
url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
|
|
||||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
||||||
),
|
|
||||||
'tnt_b_patch16_224': _cfg(
|
|
||||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
""" Multi-Head Attention
|
""" Multi-Head Attention
|
||||||
"""
|
"""
|
||||||
@ -94,6 +75,7 @@ class Block(nn.Module):
|
|||||||
drop_path=0.,
|
drop_path=0.,
|
||||||
act_layer=nn.GELU,
|
act_layer=nn.GELU,
|
||||||
norm_layer=nn.LayerNorm,
|
norm_layer=nn.LayerNorm,
|
||||||
|
legacy=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Inner transformer
|
# Inner transformer
|
||||||
@ -115,9 +97,14 @@ class Block(nn.Module):
|
|||||||
act_layer=act_layer,
|
act_layer=act_layer,
|
||||||
drop=proj_drop,
|
drop=proj_drop,
|
||||||
)
|
)
|
||||||
|
self.legacy = legacy
|
||||||
self.norm1_proj = norm_layer(dim)
|
if self.legacy:
|
||||||
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
|
self.norm1_proj = norm_layer(dim)
|
||||||
|
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
|
||||||
|
else:
|
||||||
|
self.norm1_proj = norm_layer(dim * num_pixel)
|
||||||
|
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
|
||||||
|
self.norm2_proj = norm_layer(dim_out)
|
||||||
|
|
||||||
# Outer transformer
|
# Outer transformer
|
||||||
self.norm_out = norm_layer(dim_out)
|
self.norm_out = norm_layer(dim_out)
|
||||||
@ -146,9 +133,16 @@ class Block(nn.Module):
|
|||||||
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
|
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
|
||||||
# outer
|
# outer
|
||||||
B, N, C = patch_embed.size()
|
B, N, C = patch_embed.size()
|
||||||
patch_embed = torch.cat(
|
if self.legacy:
|
||||||
[patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
|
patch_embed = torch.cat([
|
||||||
dim=1)
|
patch_embed[:, 0:1], patch_embed[:, 1:] + \
|
||||||
|
self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
|
||||||
|
], dim=1)
|
||||||
|
else:
|
||||||
|
patch_embed = torch.cat([
|
||||||
|
patch_embed[:, 0:1], patch_embed[:, 1:] + \
|
||||||
|
self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
|
||||||
|
], dim=1)
|
||||||
patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
|
patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
|
||||||
patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
|
patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
|
||||||
return pixel_embed, patch_embed
|
return pixel_embed, patch_embed
|
||||||
@ -157,7 +151,7 @@ class Block(nn.Module):
|
|||||||
class PixelEmbed(nn.Module):
|
class PixelEmbed(nn.Module):
|
||||||
""" Image to Pixel Embedding
|
""" Image to Pixel Embedding
|
||||||
"""
|
"""
|
||||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
img_size = to_2tuple(img_size)
|
img_size = to_2tuple(img_size)
|
||||||
patch_size = to_2tuple(patch_size)
|
patch_size = to_2tuple(patch_size)
|
||||||
@ -165,13 +159,18 @@ class PixelEmbed(nn.Module):
|
|||||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||||
num_patches = (self.grid_size[0]) * (self.grid_size[1])
|
num_patches = (self.grid_size[0]) * (self.grid_size[1])
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.legacy = legacy
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
self.in_dim = in_dim
|
self.in_dim = in_dim
|
||||||
new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
|
new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
|
||||||
self.new_patch_size = new_patch_size
|
self.new_patch_size = new_patch_size
|
||||||
|
|
||||||
self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
|
self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
|
||||||
self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
|
if self.legacy:
|
||||||
|
self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
|
||||||
|
else:
|
||||||
|
self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
def forward(self, x, pixel_pos):
|
def forward(self, x, pixel_pos):
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
@ -179,9 +178,14 @@ class PixelEmbed(nn.Module):
|
|||||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
||||||
_assert(W == self.img_size[1],
|
_assert(W == self.img_size[1],
|
||||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
|
||||||
x = self.proj(x)
|
if self.legacy:
|
||||||
x = self.unfold(x)
|
x = self.proj(x)
|
||||||
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
|
x = self.unfold(x)
|
||||||
|
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
|
||||||
|
else:
|
||||||
|
x = self.unfold(x)
|
||||||
|
x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
|
||||||
|
x = self.proj(x)
|
||||||
x = x + pixel_pos
|
x = x + pixel_pos
|
||||||
x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
|
x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
|
||||||
return x
|
return x
|
||||||
@ -211,6 +215,7 @@ class TNT(nn.Module):
|
|||||||
drop_path_rate=0.,
|
drop_path_rate=0.,
|
||||||
norm_layer=nn.LayerNorm,
|
norm_layer=nn.LayerNorm,
|
||||||
first_stride=4,
|
first_stride=4,
|
||||||
|
legacy=False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert global_pool in ('', 'token', 'avg')
|
assert global_pool in ('', 'token', 'avg')
|
||||||
@ -225,6 +230,7 @@ class TNT(nn.Module):
|
|||||||
in_chans=in_chans,
|
in_chans=in_chans,
|
||||||
in_dim=inner_dim,
|
in_dim=inner_dim,
|
||||||
stride=first_stride,
|
stride=first_stride,
|
||||||
|
legacy=legacy,
|
||||||
)
|
)
|
||||||
num_patches = self.pixel_embed.num_patches
|
num_patches = self.pixel_embed.num_patches
|
||||||
self.num_patches = num_patches
|
self.num_patches = num_patches
|
||||||
@ -255,6 +261,7 @@ class TNT(nn.Module):
|
|||||||
attn_drop=attn_drop_rate,
|
attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[i],
|
drop_path=dpr[i],
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
|
legacy=legacy,
|
||||||
))
|
))
|
||||||
self.blocks = nn.ModuleList(blocks)
|
self.blocks = nn.ModuleList(blocks)
|
||||||
self.norm = norm_layer(embed_dim)
|
self.norm = norm_layer(embed_dim)
|
||||||
@ -338,14 +345,38 @@ class TNT(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def checkpoint_filter_fn(state_dict, model):
|
def checkpoint_filter_fn(state_dict, model):
|
||||||
|
state_dict.pop('outer_tokens', None)
|
||||||
|
|
||||||
|
out_dict = {}
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
k = k.replace('outer_pos', 'patch_pos')
|
||||||
|
k = k.replace('inner_pos', 'pixel_pos')
|
||||||
|
k = k.replace('patch_embed', 'pixel_embed')
|
||||||
|
k = k.replace('proj_norm1', 'norm1_proj')
|
||||||
|
k = k.replace('proj_norm2', 'norm2_proj')
|
||||||
|
k = k.replace('inner_norm1', 'norm_in')
|
||||||
|
k = k.replace('inner_attn', 'attn_in')
|
||||||
|
k = k.replace('inner_norm2', 'norm_mlp_in')
|
||||||
|
k = k.replace('inner_mlp', 'mlp_in')
|
||||||
|
k = k.replace('outer_norm1', 'norm_out')
|
||||||
|
k = k.replace('outer_attn', 'attn_out')
|
||||||
|
k = k.replace('outer_norm2', 'norm_mlp')
|
||||||
|
k = k.replace('outer_mlp', 'mlp')
|
||||||
|
if k == 'pixel_pos':
|
||||||
|
B, N, C = v.shape
|
||||||
|
H = W = int(N ** 0.5)
|
||||||
|
assert H * W == N
|
||||||
|
v = v.permute(0, 2, 1).reshape(B, C, H, W)
|
||||||
|
out_dict[k] = v
|
||||||
|
|
||||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||||
if state_dict['patch_pos'].shape != model.patch_pos.shape:
|
if out_dict['patch_pos'].shape != model.patch_pos.shape:
|
||||||
state_dict['patch_pos'] = resample_abs_pos_embed(
|
out_dict['patch_pos'] = resample_abs_pos_embed(
|
||||||
state_dict['patch_pos'],
|
out_dict['patch_pos'],
|
||||||
new_size=model.pixel_embed.grid_size,
|
new_size=model.pixel_embed.grid_size,
|
||||||
num_prefix_tokens=1,
|
num_prefix_tokens=1,
|
||||||
)
|
)
|
||||||
return state_dict
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
def _create_tnt(variant, pretrained=False, **kwargs):
|
def _create_tnt(variant, pretrained=False, **kwargs):
|
||||||
@ -359,6 +390,30 @@ def _create_tnt(variant, pretrained=False, **kwargs):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||||
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||||
|
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||||
|
'first_conv': 'pixel_embed.proj', 'classifier': 'head',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'tnt_s_patch16_224': _cfg(
|
||||||
|
# hf_hub_id='timm/',
|
||||||
|
# url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
|
||||||
|
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
|
||||||
|
),
|
||||||
|
'tnt_b_patch16_224': _cfg(
|
||||||
|
# hf_hub_id='timm/',
|
||||||
|
url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
|
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
|
||||||
model_cfg = dict(
|
model_cfg = dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user