[Feature] Add transformer in transformer (#339)
* add tnt_small configs * add tnt backbone * test tnt * add tnt to model_zoo * rename the config file name * add optimizor * move tnt backbone unitest * add metric * fix keyname in arch * encapsulate "inner transformer block" and "outer transformer block" * fix TnT * Use `inner_block_cfg` and `outer_block_cfg` instead of `args` and `kwargs`. Co-authored-by: mzr1996 <mzr1996@163.com>pull/382/head
parent
db856df43e
commit
359f56ad58
|
@ -0,0 +1,29 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='TNT',
|
||||
arch='s',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
ffn_ratio=4,
|
||||
qkv_bias=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
first_stride=4,
|
||||
num_fcs=2,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
|
||||
]),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
topk=(1, 5),
|
||||
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)))
|
|
@ -0,0 +1,39 @@
|
|||
# accuracy_top-1 : 81.52 accuracy_top-5 : 95.73
|
||||
_base_ = [
|
||||
'../_base_/models/tnt_s_patch16_224.py',
|
||||
'../_base_/datasets/imagenet_bs32_pil_resize.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
img_norm_cfg = dict(
|
||||
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
size=(248, -1),
|
||||
interpolation='bicubic',
|
||||
backend='pillow'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
|
||||
dataset_type = 'ImageNet'
|
||||
data = dict(
|
||||
samples_per_gpu=32, workers_per_gpu=4, test=dict(pipeline=test_pipeline))
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(type='AdamW', lr=1e-3, weight_decay=0.05)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=0,
|
||||
warmup_by_epoch=True,
|
||||
warmup='linear',
|
||||
warmup_iters=5,
|
||||
warmup_ratio=1e-3)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=300)
|
|
@ -43,6 +43,7 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||
| Swin-Transformer tiny | 28.29 | 4.36 | 81.18 | 95.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_tiny_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925.log.json)|
|
||||
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219.log.json)|
|
||||
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742.log.json)|
|
||||
| Transformer in Transformer small* | 23.76 | 3.36 | 81.52 | 95.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/tnt/tnt_s_patch16_224_evalonly_imagenet) | [model](http://download.openmmlab.com/mmclassification/v0/transformer-in-transformer/convert/tnt_s_patch16_224_evalonly_imagenet.pth) | [log]()|
|
||||
|
||||
Models with * are converted from other repos, others are trained by ourselves.
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from .seresnext import SEResNeXt
|
|||
from .shufflenet_v1 import ShuffleNetV1
|
||||
from .shufflenet_v2 import ShuffleNetV2
|
||||
from .swin_transformer import SwinTransformer
|
||||
from .tnt import TNT
|
||||
from .vgg import VGG
|
||||
from .vision_transformer import VisionTransformer
|
||||
|
||||
|
@ -19,5 +20,5 @@ __all__ = [
|
|||
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
||||
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||
'SwinTransformer'
|
||||
'SwinTransformer', 'TNT'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,366 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
class TransformerBlock(BaseModule):
|
||||
"""Implement a transformer block in TnTLayer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension
|
||||
num_heads (int): Parallel attention heads
|
||||
ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
|
||||
Default: 4
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default 0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default False
|
||||
act_cfg (dict): The activation config for FFNs. Defalut GELU
|
||||
norm_cfg (dict): Config dict for normalization layer. Default
|
||||
layer normalization
|
||||
batch_first (bool): Key, Query and Value are shape of
|
||||
(batch, n, embed_dim) or (n, batch, embed_dim).
|
||||
(batch, n, embed_dim) is common case in CV. Default to False
|
||||
init_cfg (dict, optional): Initialization config dict. Default to None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
ffn_ratio=4,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=False,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
batch_first=True,
|
||||
init_cfg=None):
|
||||
super(TransformerBlock, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.attn = MultiheadAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
batch_first=batch_first)
|
||||
|
||||
self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
self.ffn = FFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=embed_dims * ffn_ratio,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
if not qkv_bias:
|
||||
self.attn.attn.in_proj_bias = None
|
||||
|
||||
def forward(self, x):
|
||||
x = self.attn(self.norm_attn(x), identity=x)
|
||||
x = self.ffn(self.norm_ffn(x), identity=x)
|
||||
return x
|
||||
|
||||
|
||||
class TnTLayer(BaseModule):
|
||||
"""Implement one encoder layer in Transformer in Transformer.
|
||||
|
||||
Args:
|
||||
num_pixel (int): The pixel number in target patch transformed with
|
||||
a linear projection in inner transformer
|
||||
embed_dims_inner (int): Feature dimension in inner transformer block
|
||||
embed_dims_outer (int): Feature dimension in outer transformer block
|
||||
num_heads_inner (int): Parallel attention heads in inner transformer.
|
||||
num_heads_outer (int): Parallel attention heads in outer transformer.
|
||||
inner_block_cfg (dict): Extra config of inner transformer block.
|
||||
Defaults to empty dict.
|
||||
outer_block_cfg (dict): Extra config of outer transformer block.
|
||||
Defaults to empty dict.
|
||||
norm_cfg (dict): Config dict for normalization layer. Default
|
||||
layer normalization
|
||||
init_cfg (dict, optional): Initialization config dict. Default to None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_pixel,
|
||||
embed_dims_inner,
|
||||
embed_dims_outer,
|
||||
num_heads_inner,
|
||||
num_heads_outer,
|
||||
inner_block_cfg=dict(),
|
||||
outer_block_cfg=dict(),
|
||||
norm_cfg=dict(type='LN'),
|
||||
init_cfg=None):
|
||||
super(TnTLayer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.inner_block = TransformerBlock(
|
||||
embed_dims=embed_dims_inner,
|
||||
num_heads=num_heads_inner,
|
||||
**inner_block_cfg)
|
||||
|
||||
self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1]
|
||||
self.projection = nn.Linear(
|
||||
embed_dims_inner * num_pixel, embed_dims_outer, bias=True)
|
||||
|
||||
self.outer_block = TransformerBlock(
|
||||
embed_dims=embed_dims_outer,
|
||||
num_heads=num_heads_outer,
|
||||
**outer_block_cfg)
|
||||
|
||||
def forward(self, pixel_embed, patch_embed):
|
||||
pixel_embed = self.inner_block(pixel_embed)
|
||||
|
||||
B, N, C = patch_embed.size()
|
||||
patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection(
|
||||
self.norm_proj(pixel_embed).reshape(B, N - 1, -1))
|
||||
patch_embed = self.outer_block(patch_embed)
|
||||
|
||||
return pixel_embed, patch_embed
|
||||
|
||||
|
||||
class PixelEmbed(BaseModule):
|
||||
"""Image to Pixel Embedding.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): The size of input image
|
||||
patch_size (int): The size of one patch
|
||||
in_channels (int): The num of input channels
|
||||
embed_dims_inner (int): The num of channels of the target patch
|
||||
transformed with a linear projection in inner transformer
|
||||
stride (int): The stride of the conv2d layer. We use a conv2d layer
|
||||
and a unfold layer to implement image to pixel embedding.
|
||||
init_cfg (dict, optional): Initialization config dict
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dims_inner=48,
|
||||
stride=4,
|
||||
init_cfg=None):
|
||||
super(PixelEmbed, self).__init__(init_cfg=init_cfg)
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
# patches_resolution property necessary for resizing
|
||||
# positional embedding
|
||||
patches_resolution = [
|
||||
img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
||||
]
|
||||
num_patches = patches_resolution[0] * patches_resolution[1]
|
||||
|
||||
self.img_size = img_size
|
||||
self.num_patches = num_patches
|
||||
self.embed_dims_inner = embed_dims_inner
|
||||
|
||||
new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
|
||||
self.new_patch_size = new_patch_size
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels,
|
||||
self.embed_dims_inner,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
stride=stride)
|
||||
self.unfold = nn.Unfold(
|
||||
kernel_size=new_patch_size, stride=new_patch_size)
|
||||
|
||||
def forward(self, x, pixel_pos):
|
||||
B, C, H, W = x.shape
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model " \
|
||||
f'({self.img_size[0]}*{self.img_size[1]}).'
|
||||
x = self.proj(x)
|
||||
x = self.unfold(x)
|
||||
x = x.transpose(1,
|
||||
2).reshape(B * self.num_patches, self.embed_dims_inner,
|
||||
self.new_patch_size[0],
|
||||
self.new_patch_size[1])
|
||||
x = x + pixel_pos
|
||||
x = x.reshape(B * self.num_patches, self.embed_dims_inner,
|
||||
-1).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class TNT(BaseBackbone):
|
||||
""" Transformer in Transformer
|
||||
A PyTorch implement of : `Transformer in Transformer
|
||||
<https://arxiv.org/abs/2103.00112>`_
|
||||
|
||||
Inspiration from
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size. Default to 224
|
||||
patch_size (int | tuple): The patch size. Deault to 16
|
||||
in_channels (int): Number of input channels. Default to 3
|
||||
ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
|
||||
Default: 4
|
||||
qkv_bias (bool): Enable bias for qkv if True. Default False
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Default 0.
|
||||
attn_drop_rate (float): The drop out rate for attention layer.
|
||||
Default 0.
|
||||
drop_path_rate (float): stochastic depth rate. Default 0.
|
||||
act_cfg (dict): The activation config for FFNs. Defalut GELU
|
||||
norm_cfg (dict): Config dict for normalization layer. Default
|
||||
layer normalization
|
||||
first_stride (int): The stride of the conv2d layer. We use a conv2d
|
||||
layer and a unfold layer to implement image to pixel embedding.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
|
||||
init_cfg (dict, optional): Initialization config dict
|
||||
"""
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(
|
||||
['s', 'small'], {
|
||||
'embed_dims_outer': 384,
|
||||
'embed_dims_inner': 24,
|
||||
'num_layers': 12,
|
||||
'num_heads_outer': 6,
|
||||
'num_heads_inner': 4
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['b', 'base'], {
|
||||
'embed_dims_outer': 640,
|
||||
'embed_dims_inner': 40,
|
||||
'num_layers': 12,
|
||||
'num_heads_outer': 10,
|
||||
'num_heads_inner': 4
|
||||
})
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch='b',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
ffn_ratio=4,
|
||||
qkv_bias=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
first_stride=4,
|
||||
num_fcs=2,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
|
||||
]):
|
||||
super(TNT, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {
|
||||
'embed_dims_outer', 'embed_dims_inner', 'num_layers',
|
||||
'num_heads_inner', 'num_heads_outer'
|
||||
}
|
||||
assert isinstance(arch, dict) and set(arch) == essential_keys, \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
||||
self.embed_dims_inner = self.arch_settings['embed_dims_inner']
|
||||
self.embed_dims_outer = self.arch_settings['embed_dims_outer']
|
||||
# embed_dims for consistency with other models
|
||||
self.embed_dims = self.embed_dims_outer
|
||||
self.num_layers = self.arch_settings['num_layers']
|
||||
self.num_heads_inner = self.arch_settings['num_heads_inner']
|
||||
self.num_heads_outer = self.arch_settings['num_heads_outer']
|
||||
|
||||
self.pixel_embed = PixelEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dims_inner=self.embed_dims_inner,
|
||||
stride=first_stride)
|
||||
num_patches = self.pixel_embed.num_patches
|
||||
self.num_patches = num_patches
|
||||
new_patch_size = self.pixel_embed.new_patch_size
|
||||
num_pixel = new_patch_size[0] * new_patch_size[1]
|
||||
|
||||
self.norm1_proj = build_norm_layer(norm_cfg, num_pixel *
|
||||
self.embed_dims_inner)[1]
|
||||
self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
|
||||
self.embed_dims_outer)
|
||||
self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer))
|
||||
self.patch_pos = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, self.embed_dims_outer))
|
||||
self.pixel_pos = nn.Parameter(
|
||||
torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
|
||||
new_patch_size[1]))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, drop_path_rate, self.num_layers)
|
||||
] # stochastic depth decay rule
|
||||
self.layers = ModuleList()
|
||||
for i in range(self.num_layers):
|
||||
block_cfg = dict(
|
||||
ffn_ratio=ffn_ratio,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_cfg=norm_cfg,
|
||||
batch_first=True)
|
||||
self.layers.append(
|
||||
TnTLayer(
|
||||
num_pixel=num_pixel,
|
||||
embed_dims_inner=self.embed_dims_inner,
|
||||
embed_dims_outer=self.embed_dims_outer,
|
||||
num_heads_inner=self.num_heads_inner,
|
||||
num_heads_outer=self.num_heads_outer,
|
||||
inner_block_cfg=block_cfg,
|
||||
outer_block_cfg=block_cfg,
|
||||
norm_cfg=norm_cfg))
|
||||
|
||||
self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]
|
||||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
trunc_normal_(self.patch_pos, std=.02)
|
||||
trunc_normal_(self.pixel_pos, std=.02)
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
pixel_embed = self.pixel_embed(x, self.pixel_pos)
|
||||
|
||||
patch_embed = self.norm2_proj(
|
||||
self.projection(
|
||||
self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
|
||||
patch_embed = torch.cat(
|
||||
(self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
|
||||
patch_embed = patch_embed + self.patch_pos
|
||||
patch_embed = self.drop_after_pos(patch_embed)
|
||||
|
||||
for layer in self.layers:
|
||||
pixel_embed, patch_embed = layer(pixel_embed, patch_embed)
|
||||
|
||||
patch_embed = self.norm(patch_embed)
|
||||
return patch_embed[:, 0]
|
|
@ -0,0 +1,47 @@
|
|||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import TNT
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_tnt_backbone():
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be a string path
|
||||
model = TNT()
|
||||
model.init_weights(pretrained=0)
|
||||
|
||||
# Test tnt_base_patch16_224
|
||||
model = TNT()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat.shape == torch.Size((1, 640))
|
||||
|
||||
# Test tnt with embed_dims=768
|
||||
arch = {
|
||||
'embed_dims_outer': 768,
|
||||
'embed_dims_inner': 48,
|
||||
'num_layers': 12,
|
||||
'num_heads_outer': 6,
|
||||
'num_heads_inner': 4
|
||||
}
|
||||
model = TNT(arch=arch)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat.shape == torch.Size((1, 768))
|
Loading…
Reference in New Issue