[Refactor] Use `out_type` to specify ViT-like backbone output. (#1408)

* [Refactor] Use `out_type` to specify ViT-like backbone output.

* Fix ClsBatchNormNeck

* Update mmpretrain/models/necks/mae_neck.py

---------

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Ma Zerun 2023-03-09 11:02:58 +08:00 committed by GitHub
parent 63e5b512cc
commit dbf3df21a3
64 changed files with 497 additions and 604 deletions
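
For reference, the core config change running through all the files below is that the old `avg_token` / `output_cls_token` (and sometimes `with_cls_token`) flags are folded into a single `out_type` option on the ViT-like backbones. A minimal before/after sketch; the backbone `type` field is added here only for illustration:

# Before this commit: average-pooled patch tokens needed two flags.
model = dict(
    backbone=dict(
        type='VisionTransformer',
        arch='base',
        avg_token=True,
        output_cls_token=False,
    ))

# After this commit: one option selects the backbone output.
model = dict(
    backbone=dict(
        type='VisionTransformer',
        arch='base',
        out_type='avg_featmap',  # or 'cls_token', 'featmap', 'raw'
    ))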

View File

@ -72,7 +72,7 @@ val_dataloader = dict(
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
# ann_file='meta/val.txt',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),

View File

@ -5,9 +5,8 @@ model = dict(
arch='eva-g',
img_size=224,
patch_size=14,
avg_token=True,
layer_scale_init_value=0.0,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -5,9 +5,8 @@ model = dict(
arch='l',
img_size=224,
patch_size=14,
avg_token=True,
layer_scale_init_value=0.0,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -6,9 +6,7 @@ model = dict(
arch='deit-base',
img_size=224,
patch_size=16,
output_cls_token=False,
avg_token=True,
with_cls_token=False,
out_type='avg_featmap',
),
neck=None,
head=dict(

View File

@ -6,9 +6,7 @@ model = dict(
arch='deit-small',
img_size=224,
patch_size=16,
output_cls_token=False,
avg_token=True,
with_cls_token=False,
out_type='avg_featmap',
),
neck=None,
head=dict(

View File

@ -21,8 +21,7 @@ model = dict(
img_size=224,
patch_size=16,
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False),

View File

@ -20,8 +20,7 @@ model = dict(
arch='base',
img_size=224,
patch_size=16,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False,

View File

@ -14,10 +14,9 @@ vqkd_encoder = dict(
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='featmap',
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -14,10 +14,9 @@ vqkd_encoder = dict(
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='featmap',
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -14,8 +14,7 @@ model = dict(
patch_size=16,
# 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs.
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False),

View File

@ -11,8 +11,7 @@ model = dict(
arch='base',
img_size=224,
patch_size=16,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False,

View File

@ -67,11 +67,10 @@ model = dict(
arch='base',
img_size=224,
patch_size=16,
avg_token=True, # use average token for cls head
final_norm=False, # do not use final norm
drop_path_rate=0.1,
layer_scale_init_value=0.1,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=True,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False,

View File

@ -56,8 +56,7 @@ model = dict(
img_size=224,
patch_size=16,
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=None,

View File

@ -17,7 +17,7 @@ model = dict(
img_size=224,
patch_size=16,
frozen_stages=12,
avg_token=False,
out_type='cls_token',
final_norm=True,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=dict(type='ClsBatchNormNeck', input_features=768),

View File

@ -5,9 +5,8 @@ model = dict(
arch='eva-g',
img_size=224,
patch_size=14,
avg_token=True,
layer_scale_init_value=0.0,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -5,9 +5,8 @@ model = dict(
arch='eva-g',
img_size=224,
patch_size=16,
avg_token=True,
layer_scale_init_value=0.0,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -5,9 +5,8 @@ model = dict(
arch='l',
img_size=224,
patch_size=14,
avg_token=True,
layer_scale_init_value=0.0,
output_cls_token=False,
out_type='avg_featmap',
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -55,8 +55,7 @@ model = dict(
img_size=224,
patch_size=16,
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=None,

View File

@ -18,7 +18,7 @@ model = dict(
img_size=224,
patch_size=16,
frozen_stages=12,
avg_token=False,
out_type='cls_token',
final_norm=True,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=dict(type='ClsBatchNormNeck', input_features=768),

View File

@ -57,8 +57,7 @@ model = dict(
img_size=448,
patch_size=14,
drop_path_rate=0.3, # set to 0.3
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=None,

View File

@ -56,8 +56,7 @@ model = dict(
img_size=224,
patch_size=14,
drop_path_rate=0.3, # set to 0.3
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=None,

View File

@ -56,8 +56,7 @@ model = dict(
img_size=224,
patch_size=16,
drop_path_rate=0.2, # set to 0.2
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=None,

View File

@ -18,7 +18,7 @@ model = dict(
img_size=224,
patch_size=16,
frozen_stages=24,
avg_token=False,
out_type='cls_token',
final_norm=True,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=dict(type='ClsBatchNormNeck', input_features=1024),

View File

@ -54,8 +54,7 @@ model = dict(
img_size=224,
patch_size=16,
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=None,

View File

@ -56,8 +56,7 @@ model = dict(
img_size=224,
patch_size=16,
drop_path_rate=0.1,
avg_token=True,
output_cls_token=False,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=None,

View File

@ -17,7 +17,7 @@ model = dict(
img_size=224,
patch_size=16,
frozen_stages=12,
avg_token=False,
out_type='cls_token',
final_norm=True,
init_cfg=dict(type='Pretrained', checkpoint='')),
neck=dict(type='ClsBatchNormNeck', input_features=768),

View File

@ -19,8 +19,7 @@ model = dict(
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=True,
vit_backbone=False),
with_avg_pool=True),
head=dict(
type='MoCoV3Head',
predictor=dict(

View File

@ -19,8 +19,7 @@ model = dict(
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=True,
vit_backbone=False),
with_avg_pool=True),
head=dict(
type='MoCoV3Head',
predictor=dict(

View File

@ -19,8 +19,7 @@ model = dict(
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=True,
vit_backbone=False),
with_avg_pool=True),
head=dict(
type='MoCoV3Head',
predictor=dict(

View File

@ -99,8 +99,7 @@ model = dict(
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=False,
vit_backbone=True),
with_avg_pool=False),
head=dict(
type='MoCoV3Head',
predictor=dict(

View File

@ -99,8 +99,7 @@ model = dict(
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=False,
vit_backbone=True),
with_avg_pool=False),
head=dict(
type='MoCoV3Head',
predictor=dict(

View File

@ -99,8 +99,7 @@ model = dict(
with_last_bn=True,
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=False,
vit_backbone=True),
with_avg_pool=False),
head=dict(
type='MoCoV3Head',
predictor=dict(

View File

@ -25,7 +25,7 @@ Models:
Top 1 Accuracy: 79.87
Top 5 Accuracy: 94.90
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-base_3rdparty_in1k_20221213-87a7b0a5.pth
Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-small_3rdparty_in1k_20221213-a3a34f5c.pth
Config: configs/revvit/revvit-small_8xb256_in1k.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/pyslowfast/rev/REV_VIT_S.pyth
@ -41,7 +41,7 @@ Models:
Top 1 Accuracy: 81.81
Top 5 Accuracy: 95.56
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-small_3rdparty_in1k_20221213-a3a34f5c.pth
Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-base_3rdparty_in1k_20221213-87a7b0a5.pth
Config: configs/revvit/revvit-base_8xb256_in1k.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/pyslowfast/rev/REV_VIT_B.pyth

View File

@ -9,8 +9,7 @@ model = dict(
patch_size=16,
out_indices=-1,
drop_path_rate=0.1,
avg_token=False,
output_cls_token=False,
out_type='featmap',
final_norm=False),
neck=None,
head=dict(

View File

@ -4,13 +4,12 @@ from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmpretrain.registry import MODELS
from ..utils import (BEiTAttention, resize_pos_embed,
from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed,
resize_relative_position_bias_table, to_2tuple)
from .vision_transformer import TransformerEncoderLayer, VisionTransformer
@ -203,12 +202,12 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
rel_pos_bias: torch.Tensor) -> torch.Tensor:
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.ffn(self.norm2(x)))
self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.ffn(self.ln2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(
self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
self.ln1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x)))
return x
@ -251,15 +250,21 @@ class BEiTViT(VisionTransformer):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"avg_featmap"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
avg_token (bool): Whether or not to use the mean patch token for
classification. If True, the model will only take the average
of all patch tokens. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
Defaults to False.
use_rel_pos_bias (bool): Use relative position embedding in each
@ -289,10 +294,9 @@ class BEiTViT(VisionTransformer):
bias='qv_bias',
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=False,
out_type='avg_featmap',
with_cls_token=True,
avg_token=True,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=False,
use_rel_pos_bias=True,
use_shared_rel_pos_bias=False,
@ -334,17 +338,25 @@ class BEiTViT(VisionTransformer):
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
self.interpolate_mode = interpolate_mode
# Set cls token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
self.num_extra_tokens = 1
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_extra_tokens,
@ -405,15 +417,10 @@ class BEiTViT(VisionTransformer):
self.frozen_stages = frozen_stages
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.avg_token = avg_token
if avg_token:
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
if out_type == 'avg_featmap':
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
@ -423,9 +430,10 @@ class BEiTViT(VisionTransformer):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
if self.pos_embed is not None:
x = x + resize_pos_embed(
@ -439,42 +447,32 @@ class BEiTViT(VisionTransformer):
rel_pos_bias = self.rel_pos_bias() \
if self.rel_pos_bias is not None else None
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x, rel_pos_bias)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
x = self.ln1(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.avg_token:
patch_token = patch_token.permute(0, 2, 3, 1)
patch_token = patch_token.reshape(
B, patch_resolution[0] * patch_resolution[1],
C).mean(dim=1)
patch_token = self.norm2(patch_token)
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return self.ln2(patch_token.mean(dim=1))
def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
**kwargs):
from mmengine.logging import MMLogger

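The `_format_output` helper added above is shared in near-identical form by the ViT-like backbones in this commit. A standalone sketch of its shape behaviour with plain tensors, assuming one class token and a 14x14 patch grid:

import torch

B, H, W, C = 2, 14, 14, 768
num_extra_tokens = 1                               # one class token assumed
x = torch.randn(B, num_extra_tokens + H * W, C)    # tokens after the last layer

raw = x                                            # out_type='raw'       -> (B, L, C)
cls_token = x[:, 0]                                # out_type='cls_token' -> (B, C)
patch_token = x[:, num_extra_tokens:]
featmap = patch_token.reshape(B, H, W, C).permute(0, 3, 1, 2)  # 'featmap' -> (B, C, H, W)
avg_featmap = patch_token.mean(dim=1)              # 'avg_featmap' -> (B, C); BEiT also applies ln2

assert featmap.shape == (B, C, H, W)
assert avg_featmap.shape == (B, C)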
View File

@ -43,10 +43,18 @@ class DistilledVisionTransformer(VisionTransformer):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: A tuple with the class token and the
distillation token. The shapes of both tensor are (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
@ -55,11 +63,15 @@ class DistilledVisionTransformer(VisionTransformer):
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
num_extra_tokens = 2 # cls_token, dist_token
num_extra_tokens = 2 # class token and distillation token
def __init__(self, arch='deit-base', *args, **kwargs):
super(DistilledVisionTransformer, self).__init__(
arch=arch, *args, **kwargs)
arch=arch,
with_cls_token=True,
*args,
**kwargs,
)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def forward(self, x):
@ -78,37 +90,24 @@ class DistilledVisionTransformer(VisionTransformer):
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 2:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
x = self.ln1(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
dist_token = x[:, 1]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
dist_token = None
if self.output_cls_token:
out = [patch_token, cls_token, dist_token]
else:
out = patch_token
outs.append(out)
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'cls_token':
return x[:, 0], x[:, 1]
return super()._format_output(x, hw)
def init_weights(self):
super(DistilledVisionTransformer, self).init_weights()

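A rough usage sketch of the distilled backbone after this change, assuming the class is importable from `mmpretrain.models`: with the default `out_type='cls_token'`, each output stage is a `(cls_token, dist_token)` pair instead of the old `[patch_token, cls_token, dist_token]` list.

import torch
from mmpretrain.models import DistilledVisionTransformer  # import path assumed

model = DistilledVisionTransformer(arch='deit-base')
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
cls_token, dist_token = feats[-1]
assert cls_token.shape == dist_token.shape == (1, 768)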
View File

@ -3,7 +3,7 @@ from typing import Sequence
import numpy as np
import torch
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.cnn import Linear, build_activation_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import BaseModule, ModuleList, Sequential
@ -11,7 +11,8 @@ from mmengine.utils import deprecated_api_warning
from torch import nn
from mmpretrain.registry import MODELS
from ..utils import LayerScale, MultiheadAttention, resize_pos_embed, to_2tuple
from ..utils import (LayerScale, MultiheadAttention, build_norm_layer,
resize_pos_embed, to_2tuple)
from .vision_transformer import VisionTransformer
@ -149,9 +150,7 @@ class DeiT3TransformerEncoderLayer(BaseModule):
self.embed_dims = embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
@ -162,9 +161,7 @@ class DeiT3TransformerEncoderLayer(BaseModule):
qkv_bias=qkv_bias,
use_layer_scale=use_layer_scale)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
self.ffn = DeiT3FFN(
embed_dims=embed_dims,
@ -175,14 +172,6 @@ class DeiT3TransformerEncoderLayer(BaseModule):
act_cfg=act_cfg,
use_layer_scale=use_layer_scale)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def init_weights(self):
super(DeiT3TransformerEncoderLayer, self).init_weights()
for m in self.ffn.modules():
@ -191,8 +180,8 @@ class DeiT3TransformerEncoderLayer(BaseModule):
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln1(x), identity=x)
return x
@ -237,10 +226,19 @@ class DeiT3(VisionTransformer):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
use_layer_scale (bool): Whether to use layer_scale in DeiT3.
Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
@ -288,9 +286,7 @@ class DeiT3(VisionTransformer):
'feedforward_channels': 5120
}),
}
# not using num_extra_tokens in deit3 because adding cls tokens after
# adding pos_embed
num_extra_tokens = 0
num_extra_tokens = 1 # class token
def __init__(self,
arch='base',
@ -303,8 +299,8 @@ class DeiT3(VisionTransformer):
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='cls_token',
with_cls_token=True,
output_cls_token=True,
use_layer_scale=True,
interpolate_mode='bicubic',
patch_cfg=dict(),
@ -343,13 +339,21 @@ class DeiT3(VisionTransformer):
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
@ -393,9 +397,7 @@ class DeiT3(VisionTransformer):
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
def forward(self, x):
B = x.shape[0]
@ -406,38 +408,47 @@ class DeiT3(VisionTransformer):
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
num_extra_tokens=0)
x = self.drop_after_pos(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
x = self.ln1(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1])))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(
state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
num_extra_tokens=0, # The cls token adding is after pos_embed
)

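The `_prepare_pos_embed` override above exists because DeiT3 adds the position embedding before concatenating the class token, so the stored `pos_embed` has no extra-token slot and checkpoints are resized with `num_extra_tokens=0`. A small illustration of that ordering with dummy tensors (sizes are illustrative):

import torch

B, N, C = 2, 196, 768                      # 14x14 patches of a 224px image
patch_tokens = torch.randn(B, N, C)
pos_embed = torch.zeros(1, N, C)           # no slot reserved for the class token
x = patch_tokens + pos_embed               # position embedding first ...
cls_token = torch.zeros(1, 1, C).expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)       # ... class token afterwards, (B, N + 1, C)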
View File

@ -196,7 +196,7 @@ class MixMIMBlock(TransformerEncoderLayer):
B, L, C = x.shape
shortcut = x
x = self.norm1(x)
x = self.ln1(x)
x = x.view(B, H, W, C)
# partition windows

View File

@ -1,10 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from typing import Sequence
import numpy as np
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
@ -14,7 +12,8 @@ from torch.autograd import Function as Function
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
to_2tuple)
class RevBackProp(Function):
@ -152,9 +151,7 @@ class RevTransformerEncoderLayer(BaseModule):
self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate)
self.embed_dims = embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
@ -163,9 +160,7 @@ class RevTransformerEncoderLayer(BaseModule):
proj_drop=drop_rate,
qkv_bias=qkv_bias)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
self.ffn = FFN(
embed_dims=embed_dims,
@ -178,14 +173,6 @@ class RevTransformerEncoderLayer(BaseModule):
self.layer_id = layer_id
self.seeds = {}
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def init_weights(self):
super(RevTransformerEncoderLayer, self).init_weights()
for m in self.ffn.modules():
@ -215,13 +202,13 @@ class RevTransformerEncoderLayer(BaseModule):
Implementation of Reversible TransformerEncoderLayer
`
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
`
"""
self.seed_cuda('attn')
# attention output
f_x2 = self.attn(self.norm1(x2))
f_x2 = self.attn(self.ln1(x2))
# apply droppath on attention output
self.seed_cuda('droppath')
f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2)
@ -233,7 +220,7 @@ class RevTransformerEncoderLayer(BaseModule):
# ffn output
self.seed_cuda('ffn')
g_y1 = self.ffn(self.norm2(y1))
g_y1 = self.ffn(self.ln2(y1))
# apply droppath on ffn output
torch.manual_seed(self.seeds['droppath'])
g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1)
@ -259,7 +246,7 @@ class RevTransformerEncoderLayer(BaseModule):
y1.requires_grad = True
torch.manual_seed(self.seeds['ffn'])
g_y1 = self.ffn(self.norm2(y1))
g_y1 = self.ffn(self.ln2(y1))
torch.manual_seed(self.seeds['droppath'])
g_y1 = build_dropout(self.drop_path_cfg)(g_y1)
@ -280,7 +267,7 @@ class RevTransformerEncoderLayer(BaseModule):
x2.requires_grad = True
torch.manual_seed(self.seeds['attn'])
f_x2 = self.attn(self.norm1(x2))
f_x2 = self.attn(self.ln1(x2))
torch.manual_seed(self.seeds['droppath'])
f_x2 = build_dropout(self.drop_path_cfg)(f_x2)
@ -337,7 +324,8 @@ class TwoStreamFusion(nn.Module):
class RevVisionTransformer(BaseBackbone):
"""Reversible Vision Transformer.
A PyTorch implementation of : `Reversible Vision Transformers <https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf>`_ # noqa: E501
A PyTorch implementation of : `Reversible Vision Transformers
<https://openaccess.thecvf.com/content/CVPR2022/html/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.html>`_ # noqa: E501
Args:
arch (str | dict): Vision Transformer architecture. If use string,
@ -357,8 +345,6 @@ class RevVisionTransformer(BaseBackbone):
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
@ -368,15 +354,21 @@ class RevVisionTransformer(BaseBackbone):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"avg_featmap"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
avg_token (bool): Whether or not to use the mean patch token for
classification. If True, the model will only take the average
of all patch tokens. Defaults to False.
tokens as transformer input. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
@ -443,24 +435,22 @@ class RevVisionTransformer(BaseBackbone):
'feedforward_channels': 768 * 4
}),
}
# Some structures have multiple extra tokens, like DeiT.
num_extra_tokens = 1 # cls_token
num_extra_tokens = 0 # The official RevViT doesn't have class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='avg_featmap',
with_cls_token=False,
avg_token=True,
frozen_stages=-1,
output_cls_token=False,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
@ -501,15 +491,22 @@ class RevVisionTransformer(BaseBackbone):
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
assert with_cls_token is False, 'with_cls_token=True is not supported'
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
# Set cls token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
self.num_extra_tokens = 1
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
@ -520,20 +517,6 @@ class RevVisionTransformer(BaseBackbone):
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
assert out_indices == [-1] or out_indices == [self.num_layers - 1], \
f'only support output last layer current, but got {out_indices}'
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
@ -560,20 +543,12 @@ class RevVisionTransformer(BaseBackbone):
self.frozen_stages = frozen_stages
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims * 2, postfix=1)
self.add_module(self.norm1_name, norm1)
self.avg_token = avg_token
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
self._freeze_stages()
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
super(RevVisionTransformer, self).init_weights()
if not (isinstance(self.init_cfg, dict)
@ -618,7 +593,8 @@ class RevVisionTransformer(BaseBackbone):
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze cls_token
# self.cls_token.requires_grad = False
if self.cls_token is not None:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
@ -627,17 +603,17 @@ class RevVisionTransformer(BaseBackbone):
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers) and self.final_norm:
self.norm1.eval()
for param in self.norm1.parameters():
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.cls_token is not None:
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
@ -647,10 +623,6 @@ class RevVisionTransformer(BaseBackbone):
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
x = torch.cat([x, x], dim=-1)
# forward with different conditions
@ -664,33 +636,10 @@ class RevVisionTransformer(BaseBackbone):
x = executing_fn(x, self.layers, [])
if self.final_norm:
x = self.norm1(x)
x = self.ln1(x)
x = self.fusion_layer(x)
if self.with_cls_token:
# RevViT does not allow cls_token
raise NotImplementedError
else:
# (B, H, W, C)
_, __, C = x.shape
patch_token = x.reshape(B, *patch_resolution, C)
# (B, C, H, W)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.avg_token:
# (B, H, W, C)
patch_token = patch_token.permute(0, 2, 3, 1)
# (B, L, C) -> (B, C)
patch_token = patch_token.reshape(
B, patch_resolution[0] * patch_resolution[1], C).mean(dim=1)
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
return tuple([out])
return (self._format_output(x, patch_resolution), )
@staticmethod
def _forward_vanilla_bp(hidden_state, layers, buffer=[]):
@ -706,3 +655,17 @@ class RevVisionTransformer(BaseBackbone):
attn_out, ffn_out = layer(attn_out, ffn_out)
return torch.cat([attn_out, ffn_out], dim=-1)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return patch_token.mean(dim=1)

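One detail worth noting in the reversible backbone above: the token stream is duplicated along the channel axis to form the attention and FFN streams, which is why the final norm is built with `embed_dims * 2`. A tiny sketch with illustrative sizes:

import torch
import torch.nn as nn

B, N, C = 2, 196, 768
x = torch.randn(B, N, C)
x = torch.cat([x, x], dim=-1)   # (B, N, 2C): the two reversible streams
ln1 = nn.LayerNorm(2 * C)       # matches build_norm_layer(norm_cfg, embed_dims * 2)
x = ln1(x)                      # normalised before the fusion layer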
View File

@ -5,13 +5,13 @@ from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
to_2tuple)
from .base_backbone import BaseBackbone
@ -70,9 +70,7 @@ class T2TTransformerLayer(BaseModule):
self.v_shortcut = True if input_dims is not None else False
input_dims = input_dims or embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, input_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.ln1 = build_norm_layer(norm_cfg, input_dims)
self.attn = MultiheadAttention(
input_dims=input_dims,
@ -85,9 +83,7 @@ class T2TTransformerLayer(BaseModule):
qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
v_shortcut=self.v_shortcut)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ln2 = build_norm_layer(norm_cfg, embed_dims)
self.ffn = FFN(
embed_dims=embed_dims,
@ -97,20 +93,12 @@ class T2TTransformerLayer(BaseModule):
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
if self.v_shortcut:
x = self.attn(self.norm1(x))
x = self.attn(self.ln1(x))
else:
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
return x
@ -265,10 +253,19 @@ class T2T_ViT(BaseBackbone):
``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
t2t_cfg (dict): Extra config of Tokens-to-Token module.
@ -278,7 +275,7 @@ class T2T_ViT(BaseBackbone):
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
num_extra_tokens = 1 # cls_token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
img_size=224,
@ -290,13 +287,13 @@ class T2T_ViT(BaseBackbone):
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
final_norm=True,
out_type='cls_token',
with_cls_token=True,
output_cls_token=True,
interpolate_mode='bicubic',
t2t_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
super(T2T_ViT, self).__init__(init_cfg)
super().__init__(init_cfg)
# Token-to-Token Module
self.tokens_to_token = T2TModule(
@ -307,13 +304,22 @@ class T2T_ViT(BaseBackbone):
self.patch_resolution = self.tokens_to_token.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.num_extra_tokens = 1
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
@ -360,7 +366,7 @@ class T2T_ViT(BaseBackbone):
self.final_norm = final_norm
if final_norm:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
self.norm = build_norm_layer(norm_cfg, embed_dims)
else:
self.norm = nn.Identity()
@ -401,9 +407,10 @@ class T2T_ViT(BaseBackbone):
B = x.shape[0]
x, patch_resolution = self.tokens_to_token(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
@ -413,10 +420,6 @@ class T2T_ViT(BaseBackbone):
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.encoder):
x = layer(x)
@ -425,19 +428,20 @@ class T2T_ViT(BaseBackbone):
x = self.norm(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return patch_token.mean(dim=1)

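The `build_norm_layer` change above (dropping the `[1]` index) reflects the helper swap running through this commit: `mmcv.cnn.build_norm_layer` returns a `(name, module)` pair, while the `mmpretrain` helper imported from `..utils` returns the module directly, so the `norm*_name` / `add_module` bookkeeping goes away. A quick comparison, assuming both helpers are importable as in the diffs:

import torch
from mmcv.cnn import build_norm_layer as mmcv_build_norm_layer
from mmpretrain.models.utils import build_norm_layer

norm_cfg = dict(type='LN', eps=1e-6)
name, ln_old = mmcv_build_norm_layer(norm_cfg, 768)  # ('ln', LayerNorm(768)): name + module
ln_new = build_norm_layer(norm_cfg, 768)             # LayerNorm(768) directly
assert isinstance(ln_new, torch.nn.LayerNorm)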
View File

@ -4,13 +4,13 @@ from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
to_2tuple)
from .base_backbone import BaseBackbone
@ -53,9 +53,7 @@ class TransformerEncoderLayer(BaseModule):
self.embed_dims = embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
@ -65,9 +63,7 @@ class TransformerEncoderLayer(BaseModule):
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
self.ffn = FFN(
embed_dims=embed_dims,
@ -79,11 +75,11 @@ class TransformerEncoderLayer(BaseModule):
@property
def norm1(self):
return getattr(self, self.norm1_name)
return self.ln1
@property
def norm2(self):
return getattr(self, self.norm2_name)
return self.ln2
def init_weights(self):
super(TransformerEncoderLayer, self).init_weights()
@ -93,8 +89,8 @@ class TransformerEncoderLayer(BaseModule):
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
return x
@ -134,15 +130,21 @@ class VisionTransformer(BaseBackbone):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
avg_token (bool): Whether or not to use the mean patch token for
classification. If True, the model will only take the average
of all patch tokens. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
@ -215,8 +217,8 @@ class VisionTransformer(BaseBackbone):
'feedforward_channels': 768 * 4
}),
}
# Some structures have multiple extra tokens, like DeiT.
num_extra_tokens = 1 # cls_token
num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
arch='base',
@ -229,10 +231,9 @@ class VisionTransformer(BaseBackbone):
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='cls_token',
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=True,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
@ -272,13 +273,21 @@ class VisionTransformer(BaseBackbone):
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
@ -322,34 +331,25 @@ class VisionTransformer(BaseBackbone):
self.frozen_stages = frozen_stages
if pre_norm:
_, norm_layer = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
else:
norm_layer = nn.Identity()
self.add_module('pre_norm', norm_layer)
self.pre_norm = nn.Identity()
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.avg_token = avg_token
if avg_token:
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
self._freeze_stages()
@property
def norm1(self):
return getattr(self, self.norm1_name)
return self.ln1
@property
def norm2(self):
return getattr(self, self.norm2_name)
return self.ln2
def init_weights(self):
super(VisionTransformer, self).init_weights()
@ -407,17 +407,19 @@ class VisionTransformer(BaseBackbone):
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers) and self.final_norm:
self.norm1.eval()
for param in self.norm1.parameters():
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
@ -427,41 +429,33 @@ class VisionTransformer(BaseBackbone):
x = self.drop_after_pos(x)
x = self.pre_norm(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
x = self.ln1(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.avg_token:
patch_token = patch_token.permute(0, 2, 3, 1)
patch_token = patch_token.reshape(
B, patch_resolution[0] * patch_resolution[1],
C).mean(dim=1)
patch_token = self.norm2(patch_token)
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return patch_token.mean(dim=1)
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.

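A rough end-to-end sketch of the refactored `VisionTransformer` (import path assumed): `out_type` alone now determines what each output stage contains, replacing the old `avg_token` / `output_cls_token` combination.

import torch
from mmpretrain.models import VisionTransformer  # import path assumed

inputs = torch.randn(1, 3, 224, 224)
expected = {
    'cls_token': (1, 768),          # class token
    'avg_featmap': (1, 768),        # mean of patch tokens
    'featmap': (1, 768, 14, 14),    # patch tokens as a feature map
    'raw': (1, 197, 768),           # all tokens, class token included
}
for out_type, shape in expected.items():
    model = VisionTransformer(arch='base', out_type=out_type)
    model.eval()
    with torch.no_grad():
        feats = model(inputs)
    assert feats[-1].shape == shape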
View File

@ -47,7 +47,12 @@ class DeiTClsHead(VisionTransformerClsHead):
the feature of a backbone stage. In ``DeiTClsHead``, we obtain the
feature of the last stage and forward in hidden layer if exists.
"""
_, cls_token, dist_token = feats[-1]
feat = feats[-1] # Obtain feature of the last scale.
# For backward-compatibility with the previous ViT output
if len(feat) == 3:
_, cls_token, dist_token = feat
else:
cls_token, dist_token = feat
if self.hidden_dim is None:
return cls_token, dist_token
else:

View File

@ -80,7 +80,9 @@ class VisionTransformerClsHead(ClsHead):
obtain the feature of the last stage and forward in hidden layer if
exists.
"""
_, cls_token = feats[-1]
feat = feats[-1] # Obtain feature of the last scale.
# For backward-compatibility with the previous ViT output
cls_token = feat[-1] if isinstance(feat, list) else feat
if self.hidden_dim is None:
return cls_token
else:

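A compact illustration of the compatibility branch above: `pre_logits` now receives either the new plain class-token tensor or the legacy `[patch_token, cls_token]` list, and picks the class token in both cases.

import torch

cls_token = torch.randn(2, 768)
legacy = [torch.randn(2, 768, 14, 14), cls_token]   # old-style backbone output
for feat in (cls_token, legacy):
    token = feat[-1] if isinstance(feat, list) else feat
    assert token.shape == (2, 768)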
View File

@ -183,7 +183,6 @@ class ClsBatchNormNeck(BaseModule):
self,
inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]:
"""The forward function."""
# Only apply batch norm to cls token, which is the second tensor in
# each item of the tuple
inputs = [[input_[0], self.bn(input_[1])] for input_ in inputs]
# Only apply batch norm to cls_token
inputs = [self.bn(input_) for input_ in inputs]
return tuple(inputs)

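A sketch of the updated `ClsBatchNormNeck` contract, assuming a plain `BatchNorm1d` (the exact norm settings are configured elsewhere): with `out_type='cls_token'` the backbone already returns `(B, C)` tensors, so the neck applies batch norm to each stage output directly, and `input_features` must match the backbone embed dims (768 or 1024 in the linear-probe configs above).

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(768, affine=False)   # affine setting is an assumption
inputs = (torch.randn(4, 768),)          # one (B, C) class-token tensor per stage
outputs = tuple(bn(x) for x in inputs)
assert outputs[0].shape == (4, 768)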
View File

@ -32,8 +32,6 @@ class NonLinearNeck(BaseModule):
Defaults to False.
with_avg_pool (bool): Whether to apply the global average pooling
after backbone. Defaults to True.
vit_backbone (bool): The key to indicate whether the upstream backbone
is ViT. Defaults to False.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='SyncBN').
init_cfg (dict or list[dict], optional): Initialization config dict.
@ -50,7 +48,6 @@ class NonLinearNeck(BaseModule):
with_last_bn_affine: bool = True,
with_last_bias: bool = False,
with_avg_pool: bool = True,
vit_backbone: bool = False,
norm_cfg: dict = dict(type='SyncBN'),
init_cfg: Optional[Union[dict, List[dict]]] = [
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
@ -58,7 +55,6 @@ class NonLinearNeck(BaseModule):
) -> None:
super(NonLinearNeck, self).__init__(init_cfg)
self.with_avg_pool = with_avg_pool
self.vit_backbone = vit_backbone
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.relu = nn.ReLU(inplace=True)
@ -104,8 +100,6 @@ class NonLinearNeck(BaseModule):
"""
assert len(x) == 1
x = x[0]
if self.vit_backbone:
x = x[-1]
if self.with_avg_pool:
x = self.avgpool(x)
x = x.view(x.size(0), -1)

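With the `vit_backbone` flag gone, the MoCo v3 ViT configs above rely on the backbone emitting a `(B, C)` feature (for example `out_type='avg_featmap'` or `'cls_token'`), so the neck runs with `with_avg_pool=False` and no special indexing. A small sketch of the simplified input contract:

import torch

feats = (torch.randn(8, 768),)   # tuple with one (B, C) tensor from the ViT
x = feats[0]                     # previously followed by x = x[-1] when vit_backbone=True
x = x.view(x.size(0), -1)        # no-op for a (B, C) tensor; mirrors the neck
assert x.shape == (8, 768)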
View File

@ -140,15 +140,21 @@ class BEiTPretrainViT(BEiTViT):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
It only works without input mask. Defaults to ``"avg_featmap"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
avg_token (bool): Whether or not to use the mean patch token for
classification. If True, the model will only take the average
of all patch tokens. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
use_abs_pos_emb (bool): Whether or not use absolute position embedding.
Defaults to False.
use_rel_pos_bias (bool): Whether or not use relative position bias.
@ -176,9 +182,8 @@ class BEiTPretrainViT(BEiTViT):
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
avg_token: bool = False,
out_type: str = 'avg_featmap',
frozen_stages: int = -1,
output_cls_token: bool = True,
use_abs_pos_emb: bool = False,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = True,
@ -197,9 +202,9 @@ class BEiTPretrainViT(BEiTViT):
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
avg_token=avg_token,
out_type=out_type,
with_cls_token=True,
frozen_stages=frozen_stages,
output_cls_token=output_cls_token,
use_abs_pos_emb=use_abs_pos_emb,
use_shared_rel_pos_bias=use_shared_rel_pos_bias,
use_rel_pos_bias=use_rel_pos_bias,

View File

@ -217,8 +217,17 @@ class CAEPretrainViT(BEiTViT):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
It only works without input mask. Defaults to ``"avg_featmap"``.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
layer_scale_init_value (float, optional): The init value of gamma in
@ -242,10 +251,8 @@ class CAEPretrainViT(BEiTViT):
bias: bool = 'qv_bias',
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
with_cls_token: bool = True,
avg_token: bool = False,
out_type: str = 'avg_featmap',
frozen_stages: int = -1,
output_cls_token: bool = True,
use_abs_pos_emb: bool = True,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = False,
@ -270,10 +277,9 @@ class CAEPretrainViT(BEiTViT):
bias=bias,
norm_cfg=norm_cfg,
final_norm=final_norm,
with_cls_token=with_cls_token,
avg_token=avg_token,
out_type=out_type,
with_cls_token=True,
frozen_stages=frozen_stages,
output_cls_token=output_cls_token,
use_abs_pos_emb=use_abs_pos_emb,
use_rel_pos_bias=use_rel_pos_bias,
use_shared_rel_pos_bias=use_shared_rel_pos_bias,

View File

@ -33,8 +33,17 @@ class MAEViT(VisionTransformer):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize the
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
It only works without input mask. Defaults to ``"avg_featmap"``.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
@ -55,7 +64,7 @@ class MAEViT(VisionTransformer):
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
output_cls_token: bool = True,
out_type: str = 'avg_featmap',
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
@ -70,7 +79,8 @@ class MAEViT(VisionTransformer):
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
output_cls_token=output_cls_token,
out_type=out_type,
with_cls_token=True,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
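For the masked-pretraining backbones (MAE, MaskFeat, and friends), the docstrings note that `out_type` only takes effect when no input mask is given; with a mask, the backbone follows the pretraining path instead. A rough control-flow sketch under assumed shapes, not the actual `MAEViT.forward`:

```python
import torch

# Rough control-flow sketch (assumed shapes, not the actual MAEViT.forward):
# without a mask the configured `out_type` applies; with a mask only the
# visible tokens are encoded for pretraining.
def pretrain_forward(tokens, mask=None, out_type='avg_featmap'):
    # tokens: (B, L, C) patch tokens after the transformer layers
    if mask is None:
        if out_type == 'avg_featmap':
            return (tokens.mean(dim=1), )          # (B, C)
        return (tokens, )                          # 'raw': (B, L, C)
    keep = ~mask.bool()                            # (L,) visibility mask
    return (tokens[:, keep], )                     # visible tokens only

tokens = torch.randn(2, 196, 768)
mask = torch.zeros(196)
mask[:147] = 1                                     # mask out 75% of patches
print(pretrain_forward(tokens)[0].shape)           # torch.Size([2, 768])
print(pretrain_forward(tokens, mask)[0].shape)     # torch.Size([2, 49, 768])
```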

View File

@ -178,8 +178,17 @@ class MaskFeatViT(VisionTransformer):
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize the
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
It only works without input mask. Defaults to ``"avg_featmap"``.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
@ -198,7 +207,7 @@ class MaskFeatViT(VisionTransformer):
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
output_cls_token: bool = True,
out_type: str = 'avg_featmap',
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
@ -212,7 +221,8 @@ class MaskFeatViT(VisionTransformer):
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
output_cls_token=output_cls_token,
out_type=out_type,
with_cls_token=True,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,

View File

@ -8,6 +8,14 @@ import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.utils import digit_version
# After PyTorch v1.10.0, calling torch.meshgrid without the indexing
# argument raises an extra warning. For more details,
# refer to https://github.com/pytorch/pytorch/issues/50276
if digit_version(torch.__version__) >= digit_version('1.10.0'):
torch_meshgrid = partial(torch.meshgrid, indexing='ij')
else:
torch_meshgrid = torch.meshgrid
class ConditionalPositionEncoding(BaseModule):
"""The Conditional Position Encoding (CPE) module.
@ -137,7 +145,7 @@ def build_2d_sincos_position_embedding(
h, w = patches_resolution
grid_w = torch.arange(w, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
grid_w, grid_h = torch_meshgrid(grid_w, grid_h)
assert embed_dims % 4 == 0, \
'Embed dimension must be divisible by 4.'
pos_dim = embed_dims // 4
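The shim keeps `torch.meshgrid` behavior consistent across PyTorch versions: from 1.10 the call warns unless `indexing` is given, and `indexing='ij'` matches the old default. The snippet below re-creates the idea as a standalone check of the compat wrapper plus the resulting sin-cos embedding shape; the helper function is a hypothetical re-creation for illustration, not the mmpretrain `build_2d_sincos_position_embedding`.

```python
from functools import partial

import torch
from mmengine.utils import digit_version

# Same compat idea as above: pin indexing='ij' (the pre-1.10 default) on new
# PyTorch versions so the grid layout does not change.
if digit_version(torch.__version__) >= digit_version('1.10.0'):
    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
else:
    torch_meshgrid = torch.meshgrid


def sincos_pos_embed(h, w, embed_dims=768, temperature=10000.):
    # Hypothetical re-creation of a 2D sin-cos position embedding, shown only
    # to illustrate the meshgrid usage; not the mmpretrain implementation.
    assert embed_dims % 4 == 0
    grid_w, grid_h = torch_meshgrid(
        torch.arange(w, dtype=torch.float32),
        torch.arange(h, dtype=torch.float32))
    pos_dim = embed_dims // 4
    omega = 1. / temperature**(torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
    out_w = grid_w.flatten()[:, None] * omega[None, :]
    out_h = grid_h.flatten()[:, None] * omega[None, :]
    return torch.cat(
        [out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)


print(sincos_pos_embed(14, 14).shape)  # torch.Size([196, 768])
```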

View File

@ -82,18 +82,16 @@ class TestBEiT(TestCase):
def test_forward(self):
imgs = torch.randn(1, 3, 224, 224)
# test with output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = True
cfg['out_type'] = 'cls_token'
model = BEiTViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 768))
cls_token = outs[-1]
self.assertEqual(cls_token.shape, (1, 768))
# test without output_cls_token
# test without output cls_token
cfg = deepcopy(self.cfg)
model = BEiTViT(**cfg)
outs = model(imgs)
@ -104,7 +102,7 @@ class TestBEiT(TestCase):
# test without average
cfg = deepcopy(self.cfg)
cfg['avg_token'] = False
cfg['out_type'] = 'featmap'
model = BEiTViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)

View File

@ -72,7 +72,6 @@ class TestCSPDarkNet(TestCase):
def test_forward(self):
imgs = torch.randn(3, 3, 224, 224)
# test without output_cls_token
cfg = deepcopy(self.cfg)
model = self.class_name(**cfg)
outs = model(imgs)

View File

@ -62,37 +62,19 @@ class TestDeiT(TestCase):
def test_forward(self):
imgs = torch.randn(1, 3, 224, 224)
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = True
with self.assertRaisesRegex(AssertionError, 'but got False'):
DistilledVisionTransformer(**cfg)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
model = DistilledVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 192, 14, 14))
# test with output_cls_token
# test with output cls_token
cfg = deepcopy(self.cfg)
model = DistilledVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token, dist_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 192, 14, 14))
cls_token, dist_token = outs[-1]
self.assertEqual(cls_token.shape, (1, 192))
self.assertEqual(dist_token.shape, (1, 192))
# test without output_cls_token
# test without output cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = False
cfg['out_type'] = 'featmap'
model = DistilledVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
@ -108,8 +90,7 @@ class TestDeiT(TestCase):
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token, dist_token = out
self.assertEqual(patch_token.shape, (1, 192, 14, 14))
cls_token, dist_token = out
self.assertEqual(cls_token.shape, (1, 192))
self.assertEqual(dist_token.shape, (1, 192))
@ -118,14 +99,13 @@ class TestDeiT(TestCase):
imgs2 = torch.randn(1, 3, 256, 256)
imgs3 = torch.randn(1, 3, 256, 309)
cfg = deepcopy(self.cfg)
cfg['out_type'] = 'featmap'
model = DistilledVisionTransformer(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token, dist_token = outs[-1]
featmap = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (1, 192, *expect_feat_shape))
self.assertEqual(cls_token.shape, (1, 192))
self.assertEqual(dist_token.shape, (1, 192))
self.assertEqual(featmap.shape, (1, 192, *expect_feat_shape))
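The distilled variant carries a distillation token next to the class token, which is why the default `out_type='cls_token'` output above is a (cls_token, dist_token) pair, while `'featmap'` drops both extra tokens and keeps only the patch grid. A toy tensor illustration under an assumed token layout, not the `DistilledVisionTransformer` implementation:

```python
import torch

# Toy illustration (assumed token layout: [cls, dist, patches...]);
# not the DistilledVisionTransformer implementation.
B, C, H, W = 1, 192, 14, 14
tokens = torch.randn(B, 2 + H * W, C)
cls_token, dist_token = tokens[:, 0], tokens[:, 1]             # (B, C) each
featmap = tokens[:, 2:].permute(0, 2, 1).reshape(B, C, H, W)   # (B, C, H, W)
print(cls_token.shape, dist_token.shape, featmap.shape)
```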

View File

@ -116,13 +116,13 @@ class TestDeiT3(TestCase):
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = True
with self.assertRaisesRegex(AssertionError, 'but got False'):
cfg['out_type'] = 'cls_token'
with self.assertRaisesRegex(ValueError, 'must be True'):
DeiT3(**cfg)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
cfg['out_type'] = 'featmap'
model = DeiT3(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
@ -130,26 +130,15 @@ class TestDeiT3(TestCase):
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
# test with output_cls_token
# test with output cls_token
cfg = deepcopy(self.cfg)
model = DeiT3(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
cls_token = outs[-1]
self.assertEqual(cls_token.shape, (1, 768))
# test without output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = False
model = DeiT3(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
@ -158,8 +147,7 @@ class TestDeiT3(TestCase):
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token = out
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
cls_token = out
self.assertEqual(cls_token.shape, (1, 768))
# Test forward with dynamic input size
@ -167,13 +155,13 @@ class TestDeiT3(TestCase):
imgs2 = torch.randn(1, 3, 256, 256)
imgs3 = torch.randn(1, 3, 256, 309)
cfg = deepcopy(self.cfg)
cfg['out_type'] = 'featmap'
model = DeiT3(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
featmap = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape))
self.assertEqual(cls_token.shape, (1, 768))
self.assertEqual(featmap.shape, (1, 768, *expect_feat_shape))

View File

@ -49,16 +49,6 @@ class TestRevVisionTransformer(TestCase):
self.assertEqual(layer.attn.num_heads, 16)
self.assertEqual(layer.ffn.feedforward_channels, 1024)
# Test out_indices
# TODO: to be implemented, current only support last layer
cfg = deepcopy(self.cfg)
cfg['out_indices'] = {1: 1}
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
RevVisionTransformer(**cfg)
cfg['out_indices'] = [13]
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
RevVisionTransformer(**cfg)
# Test model structure
cfg = deepcopy(self.cfg)
model = RevVisionTransformer(**cfg)
@ -108,8 +98,8 @@ class TestRevVisionTransformer(TestCase):
cfg['img_size'] = 384
model = RevVisionTransformer(**cfg)
load_checkpoint(model, checkpoint, strict=True)
resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
model.pos_embed)
resized_pos_embed = timm_resize_pos_embed(
pretrain_pos_embed, model.pos_embed, num_tokens=0)
self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))
os.remove(checkpoint)
@ -119,7 +109,7 @@ class TestRevVisionTransformer(TestCase):
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
cfg['out_type'] = 'avg_featmap'
model = RevVisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
@ -137,5 +127,5 @@ class TestRevVisionTransformer(TestCase):
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
avg_token = outs[-1]
self.assertEqual(avg_token.shape, (1, 768 * 2))
avg_featmap = outs[-1]
self.assertEqual(avg_featmap.shape, (1, 768 * 2))
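The doubled channel count in the last assertion follows from the reversible design: the backbone keeps two activation streams and fuses them at the output, so `'avg_featmap'` ends up with `2 * embed_dims` channels. A tiny shape check under that assumption (toy tensors only):

```python
import torch

# Assumption for illustration: a reversible ViT keeps two parallel token
# streams and concatenates them along the channel dim before pooling.
stream1 = torch.randn(1, 196, 768)
stream2 = torch.randn(1, 196, 768)
fused = torch.cat([stream1, stream2], dim=-1)   # (1, 196, 1536)
print(fused.mean(dim=1).shape)                  # torch.Size([1, 1536]) == (1, 768 * 2)
```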

View File

@ -94,13 +94,13 @@ class TestT2TViT(TestCase):
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = True
with self.assertRaisesRegex(AssertionError, 'but got False'):
cfg['out_type'] = 'cls_token'
with self.assertRaisesRegex(ValueError, 'must be True'):
T2T_ViT(**cfg)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
cfg['out_type'] = 'featmap'
model = T2T_ViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
@ -108,26 +108,15 @@ class TestT2TViT(TestCase):
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
# test with output_cls_token
# test with output cls_token
cfg = deepcopy(self.cfg)
model = T2T_ViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
cls_token = outs[-1]
self.assertEqual(cls_token.shape, (1, 384))
# test without output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = False
model = T2T_ViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
@ -136,22 +125,20 @@ class TestT2TViT(TestCase):
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token = out
self.assertEqual(patch_token.shape, (1, 384, 14, 14))
self.assertEqual(cls_token.shape, (1, 384))
self.assertEqual(out.shape, (1, 384))
# Test forward with dynamic input size
imgs1 = torch.randn(1, 3, 224, 224)
imgs2 = torch.randn(1, 3, 256, 256)
imgs3 = torch.randn(1, 3, 256, 309)
cfg = deepcopy(self.cfg)
cfg['out_type'] = 'featmap'
model = T2T_ViT(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
patch_token = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape))
self.assertEqual(cls_token.shape, (1, 384))

View File

@ -126,13 +126,13 @@ class TestVisionTransformer(TestCase):
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = True
with self.assertRaisesRegex(AssertionError, 'but got False'):
cfg['out_type'] = 'cls_token'
with self.assertRaisesRegex(ValueError, 'must be True'):
VisionTransformer(**cfg)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['output_cls_token'] = False
cfg['out_type'] = 'featmap'
model = VisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
@ -140,26 +140,15 @@ class TestVisionTransformer(TestCase):
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
# test with output_cls_token
# test with output cls_token
cfg = deepcopy(self.cfg)
model = VisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
cls_token = outs[-1]
self.assertEqual(cls_token.shape, (1, 768))
# test without output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = False
model = VisionTransformer(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
@ -168,22 +157,20 @@ class TestVisionTransformer(TestCase):
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
patch_token, cls_token = out
self.assertEqual(patch_token.shape, (1, 768, 14, 14))
self.assertEqual(cls_token.shape, (1, 768))
self.assertEqual(out.shape, (1, 768))
# Test forward with dynamic input size
imgs1 = torch.randn(1, 3, 224, 224)
imgs2 = torch.randn(1, 3, 256, 256)
imgs3 = torch.randn(1, 3, 256, 309)
cfg = deepcopy(self.cfg)
cfg['out_type'] = 'featmap'
model = VisionTransformer(**cfg)
for imgs in [imgs1, imgs2, imgs3]:
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token, cls_token = outs[-1]
patch_token = outs[-1]
expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
math.ceil(imgs.shape[3] / 16))
self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape))
self.assertEqual(cls_token.shape, (1, 768))

View File

@ -35,9 +35,7 @@ class TestBEiT(TestCase):
# test without mask
fake_outputs = beit_backbone(fake_inputs, None)
assert len(fake_outputs[0]) == 2
assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
assert fake_outputs[0][1].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 768])
@pytest.mark.skipif(
platform.system() == 'Windows', reason='Windows mem limit')
@ -111,10 +109,9 @@ class TestBEiT(TestCase):
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='featmap',
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,

View File

@ -25,9 +25,7 @@ def test_cae_vit():
# test without mask
fake_outputs = cae_backbone(fake_inputs, None)
assert len(fake_outputs[0]) == 2
assert fake_outputs[0][0].shape == torch.Size([1, 192, 14, 14])
assert fake_outputs[0][1].shape == torch.Size([1, 192])
assert fake_outputs[0].shape == torch.Size([1, 192])
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')

View File

@ -21,9 +21,7 @@ def test_mae_vit():
# test without mask
fake_outputs = mae_backbone(fake_inputs, None)
assert len(fake_outputs[0]) == 2
assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
assert fake_outputs[0][1].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 768])
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')

View File

@ -22,9 +22,7 @@ def test_maskfeat_vit():
# test without mask
fake_outputs = maskfeat_backbone(fake_inputs, None)
assert len(fake_outputs[0]) == 2
assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
assert fake_outputs[0][1].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 768])
@pytest.mark.skipif(

View File

@ -24,9 +24,7 @@ def test_milan_vit():
# test without mask
fake_outputs = milan_backbone(fake_inputs, None)
assert len(fake_outputs[0]) == 2
assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
assert fake_outputs[0][1].shape == torch.Size([2, 768])
assert fake_outputs[0].shape == torch.Size([2, 768])
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')

View File

@ -29,7 +29,6 @@ class TestMoCoV3(TestCase):
with_last_bn_affine=False,
with_last_bias=False,
with_avg_pool=False,
vit_backbone=True,
norm_cfg=dict(type='BN1d'))
head = dict(
type='MoCoV3Head',
@ -89,5 +88,4 @@ class TestMoCoV3(TestCase):
# test extract
fake_feats = alg(fake_inputs['inputs'][0], mode='tensor')
self.assertEqual(fake_feats[0][0].size(), torch.Size([2, 384, 14, 14]))
self.assertEqual(fake_feats[0][1].size(), torch.Size([2, 384]))
self.assertEqual(fake_feats[0].size(), torch.Size([2, 384]))

View File

@ -50,10 +50,9 @@ class TestVQKD(TestCase):
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='featmap',
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,