[Refactor] Use `out_type` to specify ViT-like backbone output. (#1408)

* [Refactor] Use `out_type` to specify ViT-like backbone output.
* Fix ClsBatchNormNeck
* Update mmpretrain/models/necks/mae_neck.py

Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>

parent 63e5b512cc
commit dbf3df21a3
@@ -72,7 +72,7 @@ val_dataloader = dict(
     dataset=dict(
         type=dataset_type,
         data_root='data/imagenet',
-        # ann_file='meta/val.txt',
+        ann_file='meta/val.txt',
         data_prefix='val',
         pipeline=test_pipeline),
     sampler=dict(type='DefaultSampler', shuffle=False),
@@ -5,9 +5,8 @@ model = dict(
         arch='eva-g',
         img_size=224,
         patch_size=14,
-        avg_token=True,
         layer_scale_init_value=0.0,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=True,
         use_rel_pos_bias=False,
         use_shared_rel_pos_bias=False,
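Across all the backbone configs in this commit the pattern is the same: the two booleans `avg_token` and `output_cls_token` collapse into a single `out_type` string. A minimal before/after sketch of the change, using values from the hunk above (the `type=` field is an assumption; the hunks do not show it):

```python
# Before: two booleans jointly selected the output format.
backbone = dict(
    type='BEiTViT',            # hypothetical; any ViT-like backbone in this PR
    arch='eva-g',
    avg_token=True,            # average the patch tokens
    output_cls_token=False,    # do not also return the class token
)

# After: one enum-like string selects the output format directly.
backbone = dict(
    type='BEiTViT',
    arch='eva-g',
    out_type='avg_featmap',    # 'cls_token' | 'featmap' | 'avg_featmap' | 'raw'
)
```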
@@ -5,9 +5,8 @@ model = dict(
         arch='l',
         img_size=224,
         patch_size=14,
-        avg_token=True,
         layer_scale_init_value=0.0,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=True,
         use_rel_pos_bias=False,
         use_shared_rel_pos_bias=False,
@@ -6,9 +6,7 @@ model = dict(
         arch='deit-base',
         img_size=224,
         patch_size=16,
-        output_cls_token=False,
-        avg_token=True,
-        with_cls_token=False,
+        out_type='avg_featmap',
     ),
     neck=None,
     head=dict(
@@ -6,9 +6,7 @@ model = dict(
         arch='deit-small',
         img_size=224,
         patch_size=16,
-        output_cls_token=False,
-        avg_token=True,
-        with_cls_token=False,
+        out_type='avg_featmap',
     ),
     neck=None,
     head=dict(
@@ -21,8 +21,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         drop_path_rate=0.1,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=False,
         use_rel_pos_bias=True,
         use_shared_rel_pos_bias=False),
@@ -20,8 +20,7 @@ model = dict(
         arch='base',
         img_size=224,
         patch_size=16,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=False,
         use_rel_pos_bias=True,
         use_shared_rel_pos_bias=False,
@@ -14,10 +14,9 @@ vqkd_encoder = dict(
    drop_path_rate=0.,
    norm_cfg=dict(type='LN', eps=1e-6),
    final_norm=True,
+    out_type='featmap',
    with_cls_token=True,
-    avg_token=False,
    frozen_stages=-1,
-    output_cls_token=False,
    use_abs_pos_emb=True,
    use_rel_pos_bias=False,
    use_shared_rel_pos_bias=False,
@@ -14,10 +14,9 @@ vqkd_encoder = dict(
    drop_path_rate=0.,
    norm_cfg=dict(type='LN', eps=1e-6),
    final_norm=True,
+    out_type='featmap',
    with_cls_token=True,
-    avg_token=False,
    frozen_stages=-1,
-    output_cls_token=False,
    use_abs_pos_emb=True,
    use_rel_pos_bias=False,
    use_shared_rel_pos_bias=False,
@@ -14,8 +14,7 @@ model = dict(
         patch_size=16,
         # 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs.
         drop_path_rate=0.1,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=False,
         use_rel_pos_bias=True,
         use_shared_rel_pos_bias=False),
@@ -11,8 +11,7 @@ model = dict(
         arch='base',
         img_size=224,
         patch_size=16,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=False,
         use_rel_pos_bias=True,
         use_shared_rel_pos_bias=False,
@@ -67,11 +67,10 @@ model = dict(
         arch='base',
         img_size=224,
         patch_size=16,
-        avg_token=True,  # use average token for cls head
         final_norm=False,  # do not use final norm
         drop_path_rate=0.1,
         layer_scale_init_value=0.1,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=True,
         use_rel_pos_bias=True,
         use_shared_rel_pos_bias=False,
@@ -56,8 +56,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         drop_path_rate=0.1,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         final_norm=False,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=None,
@@ -17,7 +17,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         frozen_stages=12,
-        avg_token=False,
+        out_type='cls_token',
         final_norm=True,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=dict(type='ClsBatchNormNeck', input_features=768),
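These linear-probe configs pair a frozen backbone (`frozen_stages=12`, `out_type='cls_token'`) with `ClsBatchNormNeck`, the neck the commit message says this PR also fixes. A minimal sketch of the idea, assuming the usual MAE-style linear-probing recipe (an affine-free BatchNorm over the class-token features before the linear head); the exact mmpretrain implementation may differ:

```python
import torch
import torch.nn as nn

class ClsBatchNormNeckSketch(nn.Module):
    """Normalize (B, C) class-token features with an affine-free BN.

    Since the backbone is frozen, a BatchNorm1d without learnable affine
    parameters simply whitens the features for the linear classifier.
    """

    def __init__(self, input_features: int):
        super().__init__()
        self.bn = nn.BatchNorm1d(input_features, affine=False, eps=1e-6)

    def forward(self, inputs):
        # Backbones now return a tuple of stage outputs; with
        # out_type='cls_token' each element is a (B, C) tensor.
        return tuple(self.bn(x) for x in inputs)

neck = ClsBatchNormNeckSketch(input_features=768)
feats = (torch.randn(4, 768), )
out = neck(feats)[0]   # (4, 768), batch-normalized
```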
@@ -5,9 +5,8 @@ model = dict(
         arch='eva-g',
         img_size=224,
         patch_size=14,
-        avg_token=True,
         layer_scale_init_value=0.0,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=True,
         use_rel_pos_bias=False,
         use_shared_rel_pos_bias=False,
@@ -5,9 +5,8 @@ model = dict(
         arch='eva-g',
         img_size=224,
         patch_size=16,
-        avg_token=True,
         layer_scale_init_value=0.0,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=True,
         use_rel_pos_bias=False,
         use_shared_rel_pos_bias=False,
@@ -5,9 +5,8 @@ model = dict(
         arch='l',
         img_size=224,
         patch_size=14,
-        avg_token=True,
         layer_scale_init_value=0.0,
-        output_cls_token=False,
+        out_type='avg_featmap',
         use_abs_pos_emb=True,
         use_rel_pos_bias=False,
         use_shared_rel_pos_bias=False,
@@ -55,8 +55,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         drop_path_rate=0.1,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         final_norm=False,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=None,
@@ -18,7 +18,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         frozen_stages=12,
-        avg_token=False,
+        out_type='cls_token',
         final_norm=True,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=dict(type='ClsBatchNormNeck', input_features=768),
@@ -57,8 +57,7 @@ model = dict(
         img_size=448,
         patch_size=14,
         drop_path_rate=0.3,  # set to 0.3
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         final_norm=False,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=None,
@@ -56,8 +56,7 @@ model = dict(
         img_size=224,
         patch_size=14,
         drop_path_rate=0.3,  # set to 0.3
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         final_norm=False,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=None,
@@ -56,8 +56,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         drop_path_rate=0.2,  # set to 0.2
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         final_norm=False,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=None,
@@ -18,7 +18,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         frozen_stages=24,
-        avg_token=False,
+        out_type='cls_token',
         final_norm=True,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=dict(type='ClsBatchNormNeck', input_features=1024),
@@ -54,8 +54,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         drop_path_rate=0.1,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         final_norm=False,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=None,
@@ -56,8 +56,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         drop_path_rate=0.1,
-        avg_token=True,
-        output_cls_token=False,
+        out_type='avg_featmap',
         final_norm=False,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=None,
@@ -17,7 +17,7 @@ model = dict(
         img_size=224,
         patch_size=16,
         frozen_stages=12,
-        avg_token=False,
+        out_type='cls_token',
         final_norm=True,
         init_cfg=dict(type='Pretrained', checkpoint='')),
     neck=dict(type='ClsBatchNormNeck', input_features=768),
@@ -19,8 +19,7 @@ model = dict(
             with_last_bn=True,
             with_last_bn_affine=False,
             with_last_bias=False,
-            with_avg_pool=True,
-            vit_backbone=False),
+            with_avg_pool=True),
     head=dict(
         type='MoCoV3Head',
         predictor=dict(
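The `vit_backbone` flag dropped here existed so the MoCo v3 neck could special-case ViT outputs (a `[patch_token, cls_token]` pair). With `out_type` the backbone already returns a single tensor of the right shape, so the neck no longer needs a ViT branch. A small shape sketch of why the flag becomes redundant:

```python
import torch

# Old-style ViT output the neck had to unpack itself:
patch_token = torch.randn(4, 768, 14, 14)
cls_token = torch.randn(4, 768)
old_out = [patch_token, cls_token]            # required vit_backbone=True handling

# New-style output with out_type='avg_featmap' is already (B, C),
# so a plain MLP neck consumes it directly.
new_out = patch_token.flatten(2).mean(dim=2)  # (4, 768)
assert new_out.shape == cls_token.shape
```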
@@ -19,8 +19,7 @@ model = dict(
             with_last_bn=True,
             with_last_bn_affine=False,
             with_last_bias=False,
-            with_avg_pool=True,
-            vit_backbone=False),
+            with_avg_pool=True),
     head=dict(
         type='MoCoV3Head',
         predictor=dict(
@@ -19,8 +19,7 @@ model = dict(
             with_last_bn=True,
             with_last_bn_affine=False,
             with_last_bias=False,
-            with_avg_pool=True,
-            vit_backbone=False),
+            with_avg_pool=True),
     head=dict(
         type='MoCoV3Head',
         predictor=dict(
@@ -99,8 +99,7 @@ model = dict(
             with_last_bn=True,
             with_last_bn_affine=False,
             with_last_bias=False,
-            with_avg_pool=False,
-            vit_backbone=True),
+            with_avg_pool=False),
     head=dict(
         type='MoCoV3Head',
         predictor=dict(
@@ -99,8 +99,7 @@ model = dict(
             with_last_bn=True,
             with_last_bn_affine=False,
             with_last_bias=False,
-            with_avg_pool=False,
-            vit_backbone=True),
+            with_avg_pool=False),
     head=dict(
         type='MoCoV3Head',
         predictor=dict(
@@ -99,8 +99,7 @@ model = dict(
             with_last_bn=True,
             with_last_bn_affine=False,
             with_last_bias=False,
-            with_avg_pool=False,
-            vit_backbone=True),
+            with_avg_pool=False),
     head=dict(
         type='MoCoV3Head',
         predictor=dict(
@@ -25,7 +25,7 @@ Models:
           Top 1 Accuracy: 79.87
           Top 5 Accuracy: 94.90
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-base_3rdparty_in1k_20221213-87a7b0a5.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-small_3rdparty_in1k_20221213-a3a34f5c.pth
     Config: configs/revvit/revvit-small_8xb256_in1k.py
     Converted From:
       Weights: https://dl.fbaipublicfiles.com/pyslowfast/rev/REV_VIT_S.pyth
@@ -41,7 +41,7 @@ Models:
           Top 1 Accuracy: 81.81
           Top 5 Accuracy: 95.56
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-small_3rdparty_in1k_20221213-a3a34f5c.pth
+    Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-base_3rdparty_in1k_20221213-87a7b0a5.pth
     Config: configs/revvit/revvit-base_8xb256_in1k.py
     Converted From:
       Weights: https://dl.fbaipublicfiles.com/pyslowfast/rev/REV_VIT_B.pyth
@@ -9,8 +9,7 @@ model = dict(
         patch_size=16,
         out_indices=-1,
         drop_path_rate=0.1,
-        avg_token=False,
-        output_cls_token=False,
+        out_type='featmap',
         final_norm=False),
     neck=None,
     head=dict(
@@ -4,13 +4,12 @@ from typing import List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
-from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
 from mmengine.model import BaseModule, ModuleList

 from mmpretrain.registry import MODELS
-from ..utils import (BEiTAttention, resize_pos_embed,
+from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed,
                      resize_relative_position_bias_table, to_2tuple)
 from .vision_transformer import TransformerEncoderLayer, VisionTransformer
@@ -203,12 +202,12 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
                 rel_pos_bias: torch.Tensor) -> torch.Tensor:
         if self.gamma_1 is None:
             x = x + self.drop_path(
-                self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
-            x = x + self.drop_path(self.ffn(self.norm2(x)))
+                self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.ffn(self.ln2(x)))
         else:
             x = x + self.drop_path(self.gamma_1 * self.attn(
-                self.norm1(x), rel_pos_bias=rel_pos_bias))
-            x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
+                self.ln1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x)))
         return x
@@ -251,15 +250,21 @@ class BEiTViT(VisionTransformer):
             Defaults to ``dict(type='LN')``.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"avg_featmap"``.
         with_cls_token (bool): Whether concatenating class token into image
             tokens as transformer input. Defaults to True.
-        avg_token (bool): Whether or not to use the mean patch token for
-            classification. If True, the model will only take the average
-            of all patch tokens. Defaults to False.
         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
             -1 means not freezing any parameters. Defaults to -1.
-        output_cls_token (bool): Whether output the cls_token. If set True,
-            ``with_cls_token`` must be True. Defaults to True.
         use_abs_pos_emb (bool): Use position embedding like vanilla ViT.
             Defaults to False.
         use_rel_pos_bias (bool): Use relative position embedding in each
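The docstring above is the output contract every backbone in this PR now shares. A quick way to check the promised shapes (a sketch; it assumes `BEiTViT` is importable from `mmpretrain.models` and uses the 'base' arch, where C=768 and a 224x224 input gives a 14x14 patch grid):

```python
import torch
from mmpretrain.models import BEiTViT  # assumed re-export; else mmpretrain.models.backbones

x = torch.randn(2, 3, 224, 224)
expected = {
    'cls_token': (2, 768),
    'featmap': (2, 768, 14, 14),
    'avg_featmap': (2, 768),
    'raw': (2, 197, 768),  # 196 patch tokens + 1 class token
}
for out_type, shape in expected.items():
    model = BEiTViT(arch='base', out_type=out_type, with_cls_token=True)
    feat = model(x)[-1]  # backbones return a tuple of stage outputs
    assert tuple(feat.shape) == shape, (out_type, feat.shape)
```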
@@ -289,10 +294,9 @@ class BEiTViT(VisionTransformer):
                  bias='qv_bias',
                  norm_cfg=dict(type='LN', eps=1e-6),
                  final_norm=False,
+                 out_type='avg_featmap',
                  with_cls_token=True,
-                 avg_token=True,
                  frozen_stages=-1,
-                 output_cls_token=False,
                  use_abs_pos_emb=False,
                  use_rel_pos_bias=True,
                  use_shared_rel_pos_bias=False,
@@ -334,17 +338,25 @@ class BEiTViT(VisionTransformer):
         self.patch_resolution = self.patch_embed.init_out_size
         num_patches = self.patch_resolution[0] * self.patch_resolution[1]

-        # Set cls token
-        if output_cls_token:
-            assert with_cls_token is True, f'with_cls_token must be True if' \
-                f'set output_cls_token to True, but got {with_cls_token}'
-        self.with_cls_token = with_cls_token
-        self.output_cls_token = output_cls_token
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+        # Set out type
+        if out_type not in self.OUT_TYPES:
+            raise ValueError(f'Unsupported `out_type` {out_type}, please '
+                             f'choose from {self.OUT_TYPES}')
+        self.out_type = out_type

-        self.interpolate_mode = interpolate_mode
+        # Set cls token
+        if with_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+            self.num_extra_tokens = 1
+        elif out_type != 'cls_token':
+            self.cls_token = None
+            self.num_extra_tokens = 0
+        else:
+            raise ValueError(
+                'with_cls_token must be True when `out_type="cls_token"`.')
+
+        # Set position embedding
+        self.interpolate_mode = interpolate_mode
         if use_abs_pos_emb:
             self.pos_embed = nn.Parameter(
                 torch.zeros(1, num_patches + self.num_extra_tokens,
@@ -405,15 +417,10 @@ class BEiTViT(VisionTransformer):
         self.frozen_stages = frozen_stages
         self.final_norm = final_norm
         if final_norm:
-            self.norm1_name, norm1 = build_norm_layer(
-                norm_cfg, self.embed_dims, postfix=1)
-            self.add_module(self.norm1_name, norm1)
+            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

-        self.avg_token = avg_token
-        if avg_token:
-            self.norm2_name, norm2 = build_norm_layer(
-                norm_cfg, self.embed_dims, postfix=2)
-            self.add_module(self.norm2_name, norm2)
+        if out_type == 'avg_featmap':
+            self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

         # freeze stages only when self.frozen_stages > 0
         if self.frozen_stages > 0:
@@ -423,9 +430,10 @@ class BEiTViT(VisionTransformer):
         B = x.shape[0]
         x, patch_resolution = self.patch_embed(x)

-        # stole cls_tokens impl from Phil Wang, thanks
-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+        if self.cls_token is not None:
+            # stole cls_tokens impl from Phil Wang, thanks
+            cls_token = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_token, x), dim=1)

         if self.pos_embed is not None:
             x = x + resize_pos_embed(
@@ -439,42 +447,32 @@ class BEiTViT(VisionTransformer):
         rel_pos_bias = self.rel_pos_bias() \
             if self.rel_pos_bias is not None else None

-        if not self.with_cls_token:
-            # Remove class token for transformer encoder input
-            x = x[:, 1:]
-
         outs = []
         for i, layer in enumerate(self.layers):
             x = layer(x, rel_pos_bias)

             if i == len(self.layers) - 1 and self.final_norm:
-                x = self.norm1(x)
+                x = self.ln1(x)

             if i in self.out_indices:
-                B, _, C = x.shape
-                if self.with_cls_token:
-                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = x[:, 0]
-                else:
-                    patch_token = x.reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = None
-
-                if self.avg_token:
-                    patch_token = patch_token.permute(0, 2, 3, 1)
-                    patch_token = patch_token.reshape(
-                        B, patch_resolution[0] * patch_resolution[1],
-                        C).mean(dim=1)
-                    patch_token = self.norm2(patch_token)
-                if self.output_cls_token:
-                    out = [patch_token, cls_token]
-                else:
-                    out = patch_token
-                outs.append(out)
+                outs.append(self._format_output(x, patch_resolution))

         return tuple(outs)

+    def _format_output(self, x, hw):
+        if self.out_type == 'raw':
+            return x
+        if self.out_type == 'cls_token':
+            return x[:, 0]
+
+        patch_token = x[:, self.num_extra_tokens:]
+        if self.out_type == 'featmap':
+            B = x.size(0)
+            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+        if self.out_type == 'avg_featmap':
+            return self.ln2(patch_token.mean(dim=1))
+
     def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
                                               **kwargs):
         from mmengine.logging import MMLogger
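The `featmap` branch of `_format_output` is just a reshape of the token sequence back onto the patch grid; the `avg_featmap` branch is a mean over tokens. A standalone rendering of that arithmetic, using the 'base' arch shapes:

```python
import torch

B, H, W, C = 2, 14, 14, 768
num_extra_tokens = 1                              # the class token
x = torch.randn(B, num_extra_tokens + H * W, C)   # (B, L, C), as in out_type='raw'

patch_token = x[:, num_extra_tokens:]             # strip extra tokens -> (B, H*W, C)
featmap = patch_token.reshape(B, H, W, C).permute(0, 3, 1, 2)
assert featmap.shape == (B, C, H, W)              # (2, 768, 14, 14)

avg_featmap = patch_token.mean(dim=1)             # (B, C); BEiT also applies ln2 here
assert avg_featmap.shape == (B, C)
```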
@@ -43,10 +43,18 @@ class DistilledVisionTransformer(VisionTransformer):
             Defaults to ``dict(type='LN')``.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Defaults to True.
-        with_cls_token (bool): Whether concatenating class token into image
-            tokens as transformer input. Defaults to True.
-        output_cls_token (bool): Whether output the cls_token. If set True,
-            ``with_cls_token`` must be True. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: A tuple with the class token and the
+              distillation token. The shapes of both tensor are (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"cls_token"``.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Defaults to "bicubic".
         patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
@@ -55,11 +63,15 @@ class DistilledVisionTransformer(VisionTransformer):
         init_cfg (dict, optional): Initialization config dict.
             Defaults to None.
     """
-    num_extra_tokens = 2  # cls_token, dist_token
+    num_extra_tokens = 2  # class token and distillation token

     def __init__(self, arch='deit-base', *args, **kwargs):
         super(DistilledVisionTransformer, self).__init__(
-            arch=arch, *args, **kwargs)
+            arch=arch,
+            with_cls_token=True,
+            *args,
+            **kwargs,
+        )
         self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

     def forward(self, x):
@@ -78,37 +90,24 @@ class DistilledVisionTransformer(VisionTransformer):
             num_extra_tokens=self.num_extra_tokens)
         x = self.drop_after_pos(x)

-        if not self.with_cls_token:
-            # Remove class token for transformer encoder input
-            x = x[:, 2:]
-
         outs = []
         for i, layer in enumerate(self.layers):
             x = layer(x)

             if i == len(self.layers) - 1 and self.final_norm:
-                x = self.norm1(x)
+                x = self.ln1(x)

             if i in self.out_indices:
-                B, _, C = x.shape
-                if self.with_cls_token:
-                    patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = x[:, 0]
-                    dist_token = x[:, 1]
-                else:
-                    patch_token = x.reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = None
-                    dist_token = None
-                if self.output_cls_token:
-                    out = [patch_token, cls_token, dist_token]
-                else:
-                    out = patch_token
-                outs.append(out)
+                outs.append(self._format_output(x, patch_resolution))

         return tuple(outs)

+    def _format_output(self, x, hw):
+        if self.out_type == 'cls_token':
+            return x[:, 0], x[:, 1]
+
+        return super()._format_output(x, hw)
+
     def init_weights(self):
         super(DistilledVisionTransformer, self).init_weights()
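Note the DeiT override: with `out_type='cls_token'` the distilled model returns two tensors (class token and distillation token) rather than one, so downstream code unpacks a pair. A toy illustration of the indexing in `_format_output`:

```python
import torch

B, L, C = 2, 198, 768          # 196 patches + cls token + dist token
x = torch.randn(B, L, C)

cls_token, dist_token = x[:, 0], x[:, 1]   # what the override returns
assert cls_token.shape == dist_token.shape == (B, C)
```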
@@ -3,7 +3,7 @@ from typing import Sequence

 import numpy as np
 import torch
-from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from mmcv.cnn import Linear, build_activation_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import PatchEmbed
 from mmengine.model import BaseModule, ModuleList, Sequential
@@ -11,7 +11,8 @@ from mmengine.utils import deprecated_api_warning
 from torch import nn

 from mmpretrain.registry import MODELS
-from ..utils import LayerScale, MultiheadAttention, resize_pos_embed, to_2tuple
+from ..utils import (LayerScale, MultiheadAttention, build_norm_layer,
+                     resize_pos_embed, to_2tuple)
 from .vision_transformer import VisionTransformer
@@ -149,9 +150,7 @@ class DeiT3TransformerEncoderLayer(BaseModule):

         self.embed_dims = embed_dims

-        self.norm1_name, norm1 = build_norm_layer(
-            norm_cfg, self.embed_dims, postfix=1)
-        self.add_module(self.norm1_name, norm1)
+        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

         self.attn = MultiheadAttention(
             embed_dims=embed_dims,
@@ -162,9 +161,7 @@ class DeiT3TransformerEncoderLayer(BaseModule):
             qkv_bias=qkv_bias,
             use_layer_scale=use_layer_scale)

-        self.norm2_name, norm2 = build_norm_layer(
-            norm_cfg, self.embed_dims, postfix=2)
-        self.add_module(self.norm2_name, norm2)
+        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

         self.ffn = DeiT3FFN(
             embed_dims=embed_dims,
@@ -175,14 +172,6 @@ class DeiT3TransformerEncoderLayer(BaseModule):
             act_cfg=act_cfg,
             use_layer_scale=use_layer_scale)

-    @property
-    def norm1(self):
-        return getattr(self, self.norm1_name)
-
-    @property
-    def norm2(self):
-        return getattr(self, self.norm2_name)
-
     def init_weights(self):
         super(DeiT3TransformerEncoderLayer, self).init_weights()
         for m in self.ffn.modules():
@@ -191,8 +180,8 @@ class DeiT3TransformerEncoderLayer(BaseModule):
             nn.init.normal_(m.bias, std=1e-6)

     def forward(self, x):
-        x = x + self.attn(self.norm1(x))
-        x = self.ffn(self.norm2(x), identity=x)
+        x = x + self.attn(self.ln1(x))
+        x = self.ffn(self.ln2(x), identity=x)
         return x
@@ -237,10 +226,19 @@ class DeiT3(VisionTransformer):
             Defaults to ``dict(type='LN')``.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"cls_token"``.
         with_cls_token (bool): Whether concatenating class token into image
             tokens as transformer input. Defaults to True.
-        output_cls_token (bool): Whether output the cls_token. If set True,
-            ``with_cls_token`` must be True. Defaults to True.
         use_layer_scale (bool): Whether to use layer_scale in DeiT3.
             Defaults to True.
         interpolate_mode (str): Select the interpolate mode for position
@@ -288,9 +286,7 @@ class DeiT3(VisionTransformer):
                      'feedforward_channels': 5120
                  }),
     }
-    # not using num_extra_tokens in deit3 because adding cls tokens after
-    # adding pos_embed
-    num_extra_tokens = 0
+    num_extra_tokens = 1  # class token

     def __init__(self,
                  arch='base',
@@ -303,8 +299,8 @@ class DeiT3(VisionTransformer):
                  qkv_bias=True,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  final_norm=True,
+                 out_type='cls_token',
                  with_cls_token=True,
-                 output_cls_token=True,
                  use_layer_scale=True,
                  interpolate_mode='bicubic',
                  patch_cfg=dict(),
@@ -343,13 +339,21 @@ class DeiT3(VisionTransformer):
         self.patch_resolution = self.patch_embed.init_out_size
         num_patches = self.patch_resolution[0] * self.patch_resolution[1]

+        # Set out type
+        if out_type not in self.OUT_TYPES:
+            raise ValueError(f'Unsupported `out_type` {out_type}, please '
+                             f'choose from {self.OUT_TYPES}')
+        self.out_type = out_type
+
         # Set cls token
-        if output_cls_token:
-            assert with_cls_token is True, f'with_cls_token must be True if' \
-                f'set output_cls_token to True, but got {with_cls_token}'
-        self.with_cls_token = with_cls_token
-        self.output_cls_token = output_cls_token
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+        if with_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+        elif out_type != 'cls_token':
+            self.cls_token = None
+            self.num_extra_tokens = 0
+        else:
+            raise ValueError(
+                'with_cls_token must be True when `out_type="cls_token"`.')

         # Set position embedding
         self.interpolate_mode = interpolate_mode
@@ -393,9 +397,7 @@ class DeiT3(VisionTransformer):

         self.final_norm = final_norm
         if final_norm:
-            self.norm1_name, norm1 = build_norm_layer(
-                norm_cfg, self.embed_dims, postfix=1)
-            self.add_module(self.norm1_name, norm1)
+            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

     def forward(self, x):
         B = x.shape[0]
@@ -406,38 +408,47 @@ class DeiT3(VisionTransformer):
             self.patch_resolution,
             patch_resolution,
             mode=self.interpolate_mode,
-            num_extra_tokens=self.num_extra_tokens)
+            num_extra_tokens=0)
         x = self.drop_after_pos(x)

-        # stole cls_tokens impl from Phil Wang, thanks
-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
-
-        if not self.with_cls_token:
-            # Remove class token for transformer encoder input
-            x = x[:, 1:]
+        if self.cls_token is not None:
+            # stole cls_tokens impl from Phil Wang, thanks
+            cls_tokens = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, x), dim=1)

         outs = []
         for i, layer in enumerate(self.layers):
             x = layer(x)

             if i == len(self.layers) - 1 and self.final_norm:
-                x = self.norm1(x)
+                x = self.ln1(x)

             if i in self.out_indices:
-                B, _, C = x.shape
-                if self.with_cls_token:
-                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = x[:, 0]
-                else:
-                    patch_token = x.reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = None
-                if self.output_cls_token:
-                    out = [patch_token, cls_token]
-                else:
-                    out = patch_token
-                outs.append(out)
+                outs.append(self._format_output(x, patch_resolution))

         return tuple(outs)

+    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
+        name = prefix + 'pos_embed'
+        if name not in state_dict.keys():
+            return
+
+        ckpt_pos_embed_shape = state_dict[name].shape
+        if self.pos_embed.shape != ckpt_pos_embed_shape:
+            from mmengine.logging import MMLogger
+            logger = MMLogger.get_current_instance()
+            logger.info(
+                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
+                f'to {self.pos_embed.shape}.')
+
+            ckpt_pos_embed_shape = to_2tuple(
+                int(np.sqrt(ckpt_pos_embed_shape[1])))
+            pos_embed_shape = self.patch_embed.init_out_size
+
+            state_dict[name] = resize_pos_embed(
+                state_dict[name],
+                ckpt_pos_embed_shape,
+                pos_embed_shape,
+                self.interpolate_mode,
+                num_extra_tokens=0,  # The cls token adding is after pos_embed
+            )
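`_prepare_pos_embed` resizes checkpoint position embeddings on load, and passes `num_extra_tokens=0` because DeiT-3 concatenates the class token only after the positional embedding is added, so the stored table covers patch positions only. A sketch of that resize under the same assumption (mmpretrain's `resize_pos_embed` is the real implementation; this `resize_patch_pos_embed` helper is hypothetical and only shows the interpolation):

```python
import torch
import torch.nn.functional as F

def resize_patch_pos_embed(pos_embed, src_hw, dst_hw, mode='bicubic'):
    # pos_embed: (1, src_h * src_w, C) -- no class-token slot to split off,
    # which is exactly what num_extra_tokens=0 means.
    _, L, C = pos_embed.shape
    assert L == src_hw[0] * src_hw[1]
    pe = pos_embed.reshape(1, src_hw[0], src_hw[1], C).permute(0, 3, 1, 2)
    pe = F.interpolate(pe, size=dst_hw, mode=mode, align_corners=False)
    return pe.permute(0, 2, 3, 1).reshape(1, dst_hw[0] * dst_hw[1], C)

pe = torch.randn(1, 14 * 14, 768)                        # from a 224px checkpoint
pe_384 = resize_patch_pos_embed(pe, (14, 14), (24, 24))  # for 384px input
assert pe_384.shape == (1, 576, 768)
```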
@@ -196,7 +196,7 @@ class MixMIMBlock(TransformerEncoderLayer):
         B, L, C = x.shape

         shortcut = x
-        x = self.norm1(x)
+        x = self.ln1(x)
         x = x.view(B, H, W, C)

         # partition windows
@@ -1,10 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import sys
 from typing import Sequence

 import numpy as np
 import torch
-from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
 from mmengine.model import BaseModule, ModuleList
@@ -14,7 +12,8 @@ from torch.autograd import Function as Function

 from mmpretrain.models.backbones.base_backbone import BaseBackbone
 from mmpretrain.registry import MODELS
-from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
+from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
+                     to_2tuple)


 class RevBackProp(Function):
@@ -152,9 +151,7 @@ class RevTransformerEncoderLayer(BaseModule):
         self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate)
         self.embed_dims = embed_dims

-        self.norm1_name, norm1 = build_norm_layer(
-            norm_cfg, self.embed_dims, postfix=1)
-        self.add_module(self.norm1_name, norm1)
+        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

         self.attn = MultiheadAttention(
             embed_dims=embed_dims,
@@ -163,9 +160,7 @@ class RevTransformerEncoderLayer(BaseModule):
             proj_drop=drop_rate,
             qkv_bias=qkv_bias)

-        self.norm2_name, norm2 = build_norm_layer(
-            norm_cfg, self.embed_dims, postfix=2)
-        self.add_module(self.norm2_name, norm2)
+        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

         self.ffn = FFN(
             embed_dims=embed_dims,
@@ -178,14 +173,6 @@ class RevTransformerEncoderLayer(BaseModule):
         self.layer_id = layer_id
         self.seeds = {}

-    @property
-    def norm1(self):
-        return getattr(self, self.norm1_name)
-
-    @property
-    def norm2(self):
-        return getattr(self, self.norm2_name)
-
     def init_weights(self):
         super(RevTransformerEncoderLayer, self).init_weights()
         for m in self.ffn.modules():
@@ -215,13 +202,13 @@ class RevTransformerEncoderLayer(BaseModule):
         Implementation of Reversible TransformerEncoderLayer

         `
-        x = x + self.attn(self.norm1(x))
-        x = self.ffn(self.norm2(x), identity=x)
+        x = x + self.attn(self.ln1(x))
+        x = self.ffn(self.ln2(x), identity=x)
         `
         """
         self.seed_cuda('attn')
         # attention output
-        f_x2 = self.attn(self.norm1(x2))
+        f_x2 = self.attn(self.ln1(x2))
         # apply droppath on attention output
         self.seed_cuda('droppath')
         f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2)
@@ -233,7 +220,7 @@ class RevTransformerEncoderLayer(BaseModule):

         # ffn output
         self.seed_cuda('ffn')
-        g_y1 = self.ffn(self.norm2(y1))
+        g_y1 = self.ffn(self.ln2(y1))
         # apply droppath on ffn output
         torch.manual_seed(self.seeds['droppath'])
         g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1)
@@ -259,7 +246,7 @@ class RevTransformerEncoderLayer(BaseModule):
             y1.requires_grad = True

             torch.manual_seed(self.seeds['ffn'])
-            g_y1 = self.ffn(self.norm2(y1))
+            g_y1 = self.ffn(self.ln2(y1))

             torch.manual_seed(self.seeds['droppath'])
             g_y1 = build_dropout(self.drop_path_cfg)(g_y1)
@@ -280,7 +267,7 @@ class RevTransformerEncoderLayer(BaseModule):
             x2.requires_grad = True

             torch.manual_seed(self.seeds['attn'])
-            f_x2 = self.attn(self.norm1(x2))
+            f_x2 = self.attn(self.ln1(x2))

             torch.manual_seed(self.seeds['droppath'])
             f_x2 = build_dropout(self.drop_path_cfg)(f_x2)
@@ -337,7 +324,8 @@ class TwoStreamFusion(nn.Module):
 class RevVisionTransformer(BaseBackbone):
     """Reversible Vision Transformer.

-    A PyTorch implementation of : `Reversible Vision Transformers <https://openaccess.thecvf.com/content/CVPR2022/papers/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.pdf>`_ # noqa: E501
+    A PyTorch implementation of : `Reversible Vision Transformers
+    <https://openaccess.thecvf.com/content/CVPR2022/html/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.html>`_ # noqa: E501

     Args:
         arch (str | dict): Vision Transformer architecture. If use string,
@@ -357,8 +345,6 @@ class RevVisionTransformer(BaseBackbone):
         patch_size (int | tuple): The patch size in patch embedding.
             Defaults to 16.
         in_channels (int): The num of input channels. Defaults to 3.
-        out_indices (Sequence | int): Output from which stages.
-            Defaults to -1, means the last stage.
         drop_rate (float): Probability of an element to be zeroed.
             Defaults to 0.
         drop_path_rate (float): stochastic depth rate. Defaults to 0.
@@ -368,15 +354,21 @@ class RevVisionTransformer(BaseBackbone):
             Defaults to ``dict(type='LN')``.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"avg_featmap"``.
         with_cls_token (bool): Whether concatenating class token into image
-            tokens as transformer input. Defaults to True.
-        avg_token (bool): Whether or not to use the mean patch token for
-            classification. If True, the model will only take the average
-            of all patch tokens. Defaults to False.
+            tokens as transformer input. Defaults to False.
         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
             -1 means not freezing any parameters. Defaults to -1.
-        output_cls_token (bool): Whether output the cls_token. If set True,
-            ``with_cls_token`` must be True. Defaults to True.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Defaults to "bicubic".
         patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
@@ -443,24 +435,22 @@ class RevVisionTransformer(BaseBackbone):
                      'feedforward_channels': 768 * 4
                  }),
     }
-    # Some structures have multiple extra tokens, like DeiT.
-    num_extra_tokens = 1  # cls_token
+    num_extra_tokens = 0  # The official RevViT doesn't have class token
+    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

     def __init__(self,
                  arch='base',
                  img_size=224,
                  patch_size=16,
                  in_channels=3,
-                 out_indices=-1,
                  drop_rate=0.,
                  drop_path_rate=0.,
                  qkv_bias=True,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  final_norm=True,
+                 out_type='avg_featmap',
                  with_cls_token=False,
-                 avg_token=True,
                  frozen_stages=-1,
-                 output_cls_token=False,
                  interpolate_mode='bicubic',
                  patch_cfg=dict(),
                  layer_cfgs=dict(),
@@ -501,15 +491,22 @@ class RevVisionTransformer(BaseBackbone):
         self.patch_resolution = self.patch_embed.init_out_size
         num_patches = self.patch_resolution[0] * self.patch_resolution[1]

-        # Set cls token
-        if output_cls_token:
-            assert with_cls_token is True, f'with_cls_token must be True if' \
-                f'set output_cls_token to True, but got {with_cls_token}'
-        self.with_cls_token = with_cls_token
-        assert with_cls_token is False, 'with_cls_token=True is not supported'
-        self.output_cls_token = output_cls_token
+        # Set out type
+        if out_type not in self.OUT_TYPES:
+            raise ValueError(f'Unsupported `out_type` {out_type}, please '
+                             f'choose from {self.OUT_TYPES}')
+        self.out_type = out_type

+        # Set cls token
+        if with_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+            self.num_extra_tokens = 1
+        elif out_type != 'cls_token':
+            self.cls_token = None
+            self.num_extra_tokens = 0
+        else:
+            raise ValueError(
+                'with_cls_token must be True when `out_type="cls_token"`.')
+
+        # Set position embedding
         self.interpolate_mode = interpolate_mode
@@ -520,20 +517,6 @@ class RevVisionTransformer(BaseBackbone):

         self.drop_after_pos = nn.Dropout(p=drop_rate)

-        if isinstance(out_indices, int):
-            out_indices = [out_indices]
-        assert isinstance(out_indices, Sequence), \
-            f'"out_indices" must by a sequence or int, ' \
-            f'get {type(out_indices)} instead.'
-        for i, index in enumerate(out_indices):
-            if index < 0:
-                out_indices[i] = self.num_layers + index
-            assert 0 <= out_indices[i] <= self.num_layers, \
-                f'Invalid out_indices {index}'
-        self.out_indices = out_indices
-        assert out_indices == [-1] or out_indices == [self.num_layers - 1], \
-            f'only support output last layer current, but got {out_indices}'
-
         # stochastic depth decay rule
         dpr = np.linspace(0, drop_path_rate, self.num_layers)
@@ -560,20 +543,12 @@ class RevVisionTransformer(BaseBackbone):
         self.frozen_stages = frozen_stages
         self.final_norm = final_norm
         if final_norm:
-            self.norm1_name, norm1 = build_norm_layer(
-                norm_cfg, self.embed_dims * 2, postfix=1)
-            self.add_module(self.norm1_name, norm1)
-
-        self.avg_token = avg_token
+            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2)

         # freeze stages only when self.frozen_stages > 0
         if self.frozen_stages > 0:
             self._freeze_stages()

-    @property
-    def norm1(self):
-        return getattr(self, self.norm1_name)
-
     def init_weights(self):
         super(RevVisionTransformer, self).init_weights()
         if not (isinstance(self.init_cfg, dict)
@@ -618,7 +593,8 @@ class RevVisionTransformer(BaseBackbone):
         for param in self.patch_embed.parameters():
             param.requires_grad = False
         # freeze cls_token
-        # self.cls_token.requires_grad = False
+        if self.cls_token is not None:
+            self.cls_token.requires_grad = False
         # freeze layers
         for i in range(1, self.frozen_stages + 1):
             m = self.layers[i - 1]
@@ -627,17 +603,17 @@ class RevVisionTransformer(BaseBackbone):
                 param.requires_grad = False
         # freeze the last layer norm
         if self.frozen_stages == len(self.layers) and self.final_norm:
-            self.norm1.eval()
-            for param in self.norm1.parameters():
+            self.ln1.eval()
+            for param in self.ln1.parameters():
                 param.requires_grad = False

     def forward(self, x):
         B = x.shape[0]
         x, patch_resolution = self.patch_embed(x)

-        # stole cls_tokens impl from Phil Wang, thanks
-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+        if self.cls_token is not None:
+            cls_token = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_token, x), dim=1)

         x = x + resize_pos_embed(
             self.pos_embed,
@@ -647,10 +623,6 @@ class RevVisionTransformer(BaseBackbone):
             num_extra_tokens=self.num_extra_tokens)
         x = self.drop_after_pos(x)

-        if not self.with_cls_token:
-            # Remove class token for transformer encoder input
-            x = x[:, 1:]
-
         x = torch.cat([x, x], dim=-1)

         # forward with different conditions
@@ -664,33 +636,10 @@ class RevVisionTransformer(BaseBackbone):
         x = executing_fn(x, self.layers, [])

         if self.final_norm:
-            x = self.norm1(x)
+            x = self.ln1(x)
         x = self.fusion_layer(x)

-        if self.with_cls_token:
-            # RevViT does not allow cls_token
-            raise NotImplementedError
-        else:
-            # (B, H, W, C)
-            _, __, C = x.shape
-            patch_token = x.reshape(B, *patch_resolution, C)
-            # (B, C, H, W)
-            patch_token = patch_token.permute(0, 3, 1, 2)
-            cls_token = None
-
-        if self.avg_token:
-            # (B, H, W, C)
-            patch_token = patch_token.permute(0, 2, 3, 1)
-            # (B, L, C) -> (B, C)
-            patch_token = patch_token.reshape(
-                B, patch_resolution[0] * patch_resolution[1], C).mean(dim=1)
-
-        if self.output_cls_token:
-            out = [patch_token, cls_token]
-        else:
-            out = patch_token
-
-        return tuple([out])
+        return (self._format_output(x, patch_resolution), )

     @staticmethod
     def _forward_vanilla_bp(hidden_state, layers, buffer=[]):
@@ -706,3 +655,17 @@ class RevVisionTransformer(BaseBackbone):
             attn_out, ffn_out = layer(attn_out, ffn_out)

         return torch.cat([attn_out, ffn_out], dim=-1)
+
+    def _format_output(self, x, hw):
+        if self.out_type == 'raw':
+            return x
+        if self.out_type == 'cls_token':
+            return x[:, 0]
+
+        patch_token = x[:, self.num_extra_tokens:]
+        if self.out_type == 'featmap':
+            B = x.size(0)
+            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+        if self.out_type == 'avg_featmap':
+            return patch_token.mean(dim=1)
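RevViT runs two residual streams, so the tensor entering `ln1` is the concatenation of both (hence `build_norm_layer(norm_cfg, self.embed_dims * 2)` above), and `fusion_layer` collapses them back to `embed_dims` before `_format_output`. A shape sketch of the concat-then-fuse round trip, with averaging as one plausible fusion mode:

```python
import torch

B, N, C = 2, 196, 768
x = torch.randn(B, N, C)

x2 = torch.cat([x, x], dim=-1)        # duplicate into two streams: (B, N, 2C)
# ... reversible layers and ln1 operate on the 2C-wide tensor ...
stream_a, stream_b = x2.chunk(2, dim=-1)
fused = (stream_a + stream_b) / 2     # fuse back to (B, N, C)
assert fused.shape == (B, N, C)
```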
@@ -5,13 +5,13 @@ from typing import Sequence
 import numpy as np
 import torch
 import torch.nn as nn
-from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN
 from mmengine.model import BaseModule, ModuleList
 from mmengine.model.weight_init import trunc_normal_

 from mmpretrain.registry import MODELS
-from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
+from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
+                     to_2tuple)
 from .base_backbone import BaseBackbone
@@ -70,9 +70,7 @@ class T2TTransformerLayer(BaseModule):
         self.v_shortcut = True if input_dims is not None else False
         input_dims = input_dims or embed_dims

-        self.norm1_name, norm1 = build_norm_layer(
-            norm_cfg, input_dims, postfix=1)
-        self.add_module(self.norm1_name, norm1)
+        self.ln1 = build_norm_layer(norm_cfg, input_dims)

         self.attn = MultiheadAttention(
             input_dims=input_dims,
@@ -85,9 +83,7 @@ class T2TTransformerLayer(BaseModule):
             qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
             v_shortcut=self.v_shortcut)

-        self.norm2_name, norm2 = build_norm_layer(
-            norm_cfg, embed_dims, postfix=2)
-        self.add_module(self.norm2_name, norm2)
+        self.ln2 = build_norm_layer(norm_cfg, embed_dims)

         self.ffn = FFN(
             embed_dims=embed_dims,
@@ -97,20 +93,12 @@ class T2TTransformerLayer(BaseModule):
             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
             act_cfg=act_cfg)

-    @property
-    def norm1(self):
-        return getattr(self, self.norm1_name)
-
-    @property
-    def norm2(self):
-        return getattr(self, self.norm2_name)
-
     def forward(self, x):
         if self.v_shortcut:
-            x = self.attn(self.norm1(x))
+            x = self.attn(self.ln1(x))
         else:
-            x = x + self.attn(self.norm1(x))
-        x = self.ffn(self.norm2(x), identity=x)
+            x = x + self.attn(self.ln1(x))
+        x = self.ffn(self.ln2(x), identity=x)
         return x
@@ -265,10 +253,19 @@ class T2T_ViT(BaseBackbone):
             ``dict(type='LN')``.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"cls_token"``.
         with_cls_token (bool): Whether concatenating class token into image
             tokens as transformer input. Defaults to True.
-        output_cls_token (bool): Whether output the cls_token. If set True,
-            ``with_cls_token`` must be True. Defaults to True.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Defaults to "bicubic".
         t2t_cfg (dict): Extra config of Tokens-to-Token module.
|
|||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
num_extra_tokens = 1 # cls_token
|
||||
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
|
@@ -290,13 +287,13 @@ class T2T_ViT(BaseBackbone):
                  drop_path_rate=0.,
                  norm_cfg=dict(type='LN'),
                  final_norm=True,
+                 out_type='cls_token',
                  with_cls_token=True,
-                 output_cls_token=True,
                  interpolate_mode='bicubic',
                  t2t_cfg=dict(),
                  layer_cfgs=dict(),
                  init_cfg=None):
-        super(T2T_ViT, self).__init__(init_cfg)
+        super().__init__(init_cfg)

         # Token-to-Token Module
         self.tokens_to_token = T2TModule(
@@ -307,13 +304,22 @@ class T2T_ViT(BaseBackbone):
         self.patch_resolution = self.tokens_to_token.init_out_size
         num_patches = self.patch_resolution[0] * self.patch_resolution[1]

+        # Set out type
+        if out_type not in self.OUT_TYPES:
+            raise ValueError(f'Unsupported `out_type` {out_type}, please '
+                             f'choose from {self.OUT_TYPES}')
+        self.out_type = out_type
+
         # Set cls token
-        if output_cls_token:
-            assert with_cls_token is True, f'with_cls_token must be True if' \
-                f'set output_cls_token to True, but got {with_cls_token}'
-        self.with_cls_token = with_cls_token
-        self.output_cls_token = output_cls_token
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
+        if with_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
+            self.num_extra_tokens = 1
+        elif out_type != 'cls_token':
+            self.cls_token = None
+            self.num_extra_tokens = 0
+        else:
+            raise ValueError(
+                'with_cls_token must be True when `out_type="cls_token"`.')

         # Set position embedding
         self.interpolate_mode = interpolate_mode
@@ -360,7 +366,7 @@ class T2T_ViT(BaseBackbone):

         self.final_norm = final_norm
         if final_norm:
-            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+            self.norm = build_norm_layer(norm_cfg, embed_dims)
         else:
             self.norm = nn.Identity()
@@ -401,9 +407,10 @@ class T2T_ViT(BaseBackbone):
         B = x.shape[0]
         x, patch_resolution = self.tokens_to_token(x)

-        # stole cls_tokens impl from Phil Wang, thanks
-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+        if self.cls_token is not None:
+            # stole cls_tokens impl from Phil Wang, thanks
+            cls_token = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_token, x), dim=1)

         x = x + resize_pos_embed(
             self.pos_embed,
@@ -413,10 +420,6 @@ class T2T_ViT(BaseBackbone):
             num_extra_tokens=self.num_extra_tokens)
         x = self.drop_after_pos(x)

-        if not self.with_cls_token:
-            # Remove class token for transformer encoder input
-            x = x[:, 1:]
-
         outs = []
         for i, layer in enumerate(self.encoder):
             x = layer(x)
@@ -425,19 +428,20 @@ class T2T_ViT(BaseBackbone):
                 x = self.norm(x)

             if i in self.out_indices:
-                B, _, C = x.shape
-                if self.with_cls_token:
-                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = x[:, 0]
-                else:
-                    patch_token = x.reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = None
-                if self.output_cls_token:
-                    out = [patch_token, cls_token]
-                else:
-                    out = patch_token
-                outs.append(out)
+                outs.append(self._format_output(x, patch_resolution))

         return tuple(outs)

+    def _format_output(self, x, hw):
+        if self.out_type == 'raw':
+            return x
+        if self.out_type == 'cls_token':
+            return x[:, 0]
+
+        patch_token = x[:, self.num_extra_tokens:]
+        if self.out_type == 'featmap':
+            B = x.size(0)
+            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+        if self.out_type == 'avg_featmap':
+            return patch_token.mean(dim=1)
@@ -4,13 +4,13 @@ from typing import Sequence
 import numpy as np
 import torch
 import torch.nn as nn
-from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
 from mmengine.model import BaseModule, ModuleList
 from mmengine.model.weight_init import trunc_normal_

 from mmpretrain.registry import MODELS
-from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
+from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
+                     to_2tuple)
 from .base_backbone import BaseBackbone
@@ -53,9 +53,7 @@ class TransformerEncoderLayer(BaseModule):

         self.embed_dims = embed_dims

-        self.norm1_name, norm1 = build_norm_layer(
-            norm_cfg, self.embed_dims, postfix=1)
-        self.add_module(self.norm1_name, norm1)
+        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

         self.attn = MultiheadAttention(
             embed_dims=embed_dims,
@@ -65,9 +63,7 @@ class TransformerEncoderLayer(BaseModule):
             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
             qkv_bias=qkv_bias)

-        self.norm2_name, norm2 = build_norm_layer(
-            norm_cfg, self.embed_dims, postfix=2)
-        self.add_module(self.norm2_name, norm2)
+        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

         self.ffn = FFN(
             embed_dims=embed_dims,
@@ -79,11 +75,11 @@ class TransformerEncoderLayer(BaseModule):

     @property
     def norm1(self):
-        return getattr(self, self.norm1_name)
+        return self.ln1

     @property
     def norm2(self):
-        return getattr(self, self.norm2_name)
+        return self.ln2

     def init_weights(self):
         super(TransformerEncoderLayer, self).init_weights()
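Unlike the other layers in this PR, `TransformerEncoderLayer` keeps its `norm1`/`norm2` properties and merely redirects them, so any external code that still accesses `.norm1` keeps working after the rename. The pattern in isolation:

```python
import torch.nn as nn

class Layer(nn.Module):
    def __init__(self, dims: int = 768):
        super().__init__()
        self.ln1 = nn.LayerNorm(dims)   # new canonical name

    @property
    def norm1(self):
        # Backward-compatible alias: old call sites using `.norm1` still work.
        return self.ln1

layer = Layer()
assert layer.norm1 is layer.ln1
```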
@@ -93,8 +89,8 @@ class TransformerEncoderLayer(BaseModule):
             nn.init.normal_(m.bias, std=1e-6)

     def forward(self, x):
-        x = x + self.attn(self.norm1(x))
-        x = self.ffn(self.norm2(x), identity=x)
+        x = x + self.attn(self.ln1(x))
+        x = self.ffn(self.ln2(x), identity=x)
         return x
@@ -134,15 +130,21 @@ class VisionTransformer(BaseBackbone):
             Defaults to ``dict(type='LN')``.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            Defaults to ``"cls_token"``.
         with_cls_token (bool): Whether concatenating class token into image
             tokens as transformer input. Defaults to True.
-        avg_token (bool): Whether or not to use the mean patch token for
-            classification. If True, the model will only take the average
-            of all patch tokens. Defaults to False.
         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
             -1 means not freezing any parameters. Defaults to -1.
-        output_cls_token (bool): Whether output the cls_token. If set True,
-            ``with_cls_token`` must be True. Defaults to True.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Defaults to "bicubic".
         patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
@@ -215,8 +217,8 @@ class VisionTransformer(BaseBackbone):
                      'feedforward_channels': 768 * 4
                  }),
     }
-    # Some structures have multiple extra tokens, like DeiT.
-    num_extra_tokens = 1  # cls_token
+    num_extra_tokens = 1  # class token
+    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

     def __init__(self,
                  arch='base',
@@ -229,10 +231,9 @@ class VisionTransformer(BaseBackbone):
                  qkv_bias=True,
                  norm_cfg=dict(type='LN', eps=1e-6),
                  final_norm=True,
+                 out_type='cls_token',
                  with_cls_token=True,
-                 avg_token=False,
                  frozen_stages=-1,
-                 output_cls_token=True,
                  interpolate_mode='bicubic',
                  patch_cfg=dict(),
                  layer_cfgs=dict(),
@@ -272,13 +273,21 @@ class VisionTransformer(BaseBackbone):
         self.patch_resolution = self.patch_embed.init_out_size
         num_patches = self.patch_resolution[0] * self.patch_resolution[1]

+        # Set out type
+        if out_type not in self.OUT_TYPES:
+            raise ValueError(f'Unsupported `out_type` {out_type}, please '
+                             f'choose from {self.OUT_TYPES}')
+        self.out_type = out_type
+
         # Set cls token
-        if output_cls_token:
-            assert with_cls_token is True, f'with_cls_token must be True if' \
-                f'set output_cls_token to True, but got {with_cls_token}'
         self.with_cls_token = with_cls_token
-        self.output_cls_token = output_cls_token
         if with_cls_token:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
+        elif out_type != 'cls_token':
+            self.cls_token = None
+            self.num_extra_tokens = 0
+        else:
+            raise ValueError(
+                'with_cls_token must be True when `out_type="cls_token"`.')

         # Set position embedding
         self.interpolate_mode = interpolate_mode
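The constructor now fails fast on inconsistent settings instead of relying on an assert buried behind `output_cls_token`; a quick sketch of the new behavior (same hypothetical import as above):

    # Fine: patch-token based outputs need no class token.
    VisionTransformer(arch='base', with_cls_token=False, out_type='avg_featmap')

    # Raises ValueError: with_cls_token must be True when out_type="cls_token".
    VisionTransformer(arch='base', with_cls_token=False, out_type='cls_token')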
@@ -322,34 +331,25 @@ class VisionTransformer(BaseBackbone):

         self.frozen_stages = frozen_stages
         if pre_norm:
-            _, norm_layer = build_norm_layer(
-                norm_cfg, self.embed_dims, postfix=1)
+            self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
         else:
-            norm_layer = nn.Identity()
-        self.add_module('pre_norm', norm_layer)
+            self.pre_norm = nn.Identity()

         self.final_norm = final_norm
         if final_norm:
-            self.norm1_name, norm1 = build_norm_layer(
-                norm_cfg, self.embed_dims, postfix=1)
-            self.add_module(self.norm1_name, norm1)
+            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

-        self.avg_token = avg_token
-        if avg_token:
-            self.norm2_name, norm2 = build_norm_layer(
-                norm_cfg, self.embed_dims, postfix=2)
-            self.add_module(self.norm2_name, norm2)
         # freeze stages only when self.frozen_stages > 0
         if self.frozen_stages > 0:
             self._freeze_stages()

     @property
     def norm1(self):
-        return getattr(self, self.norm1_name)
+        return self.ln1

     @property
     def norm2(self):
-        return getattr(self, self.norm2_name)
+        return self.ln2

     def init_weights(self):
         super(VisionTransformer, self).init_weights()
@@ -407,17 +407,19 @@ class VisionTransformer(BaseBackbone):
                 param.requires_grad = False
         # freeze the last layer norm
         if self.frozen_stages == len(self.layers) and self.final_norm:
-            self.norm1.eval()
-            for param in self.norm1.parameters():
+            self.ln1.eval()
+            for param in self.ln1.parameters():
                 param.requires_grad = False

     def forward(self, x):
         B = x.shape[0]
         x, patch_resolution = self.patch_embed(x)

+        if self.cls_token is not None:
             # stole cls_tokens impl from Phil Wang, thanks
-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+            cls_token = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_token, x), dim=1)

         x = x + resize_pos_embed(
             self.pos_embed,
             self.patch_resolution,
@@ -427,41 +429,33 @@ class VisionTransformer(BaseBackbone):
         x = self.drop_after_pos(x)

         x = self.pre_norm(x)
         if not self.with_cls_token:
             # Remove class token for transformer encoder input
             x = x[:, 1:]

         outs = []
         for i, layer in enumerate(self.layers):
             x = layer(x)

             if i == len(self.layers) - 1 and self.final_norm:
-                x = self.norm1(x)
+                x = self.ln1(x)

             if i in self.out_indices:
-                B, _, C = x.shape
-                if self.with_cls_token:
-                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = x[:, 0]
-                else:
-                    patch_token = x.reshape(B, *patch_resolution, C)
-                    patch_token = patch_token.permute(0, 3, 1, 2)
-                    cls_token = None
-                if self.avg_token:
-                    patch_token = patch_token.permute(0, 2, 3, 1)
-                    patch_token = patch_token.reshape(
-                        B, patch_resolution[0] * patch_resolution[1],
-                        C).mean(dim=1)
-                    patch_token = self.norm2(patch_token)
-                if self.output_cls_token:
-                    out = [patch_token, cls_token]
-                else:
-                    out = patch_token
-                outs.append(out)
+                outs.append(self._format_output(x, patch_resolution))

         return tuple(outs)

+    def _format_output(self, x, hw):
+        if self.out_type == 'raw':
+            return x
+        if self.out_type == 'cls_token':
+            return x[:, 0]
+
+        patch_token = x[:, self.num_extra_tokens:]
+        if self.out_type == 'featmap':
+            B = x.size(0)
+            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
+            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
+        if self.out_type == 'avg_featmap':
+            return patch_token.mean(dim=1)
+
     def get_layer_depth(self, param_name: str, prefix: str = ''):
         """Get the layer-wise depth of a parameter.
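The `_format_output` branches reduce to a handful of tensor ops. The following standalone fragment mirrors them for a hypothetical (B, L, C) = (2, 197, 768) feature with one extra (class) token, matching the shapes documented earlier:

    import torch

    B, H, W, C = 2, 14, 14, 768
    x = torch.randn(B, 1 + H * W, C)        # 'raw': class token + patch tokens

    cls_token = x[:, 0]                     # 'cls_token' -> (2, 768)
    patch_token = x[:, 1:]                  # drop num_extra_tokens == 1
    featmap = patch_token.reshape(B, H, W, C).permute(0, 3, 1, 2)
    avg_featmap = patch_token.mean(dim=1)   # 'avg_featmap' -> (2, 768)
    assert featmap.shape == (2, 768, 14, 14)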
@@ -47,7 +47,12 @@ class DeiTClsHead(VisionTransformerClsHead):
        the feature of a backbone stage. In ``DeiTClsHead``, we obtain the
        feature of the last stage and forward it through the hidden layer if
        one exists.
        """
-        _, cls_token, dist_token = feats[-1]
+        feat = feats[-1]  # Obtain feature of the last scale.
+        # For backward-compatibility with the previous ViT output
+        if len(feat) == 3:
+            _, cls_token, dist_token = feat
+        else:
+            cls_token, dist_token = feat
        if self.hidden_dim is None:
            return cls_token, dist_token
        else:
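In other words, the head now accepts both the legacy `[patch_token, cls_token, dist_token]` layout and the new two-element output; a shape-only sketch of the dispatch (the tensors are stand-ins):

    import torch

    cls_token, dist_token = torch.randn(1, 192), torch.randn(1, 192)
    patch_token = torch.randn(1, 192, 14, 14)

    for feat in ([patch_token, cls_token, dist_token],  # legacy output
                 (cls_token, dist_token)):              # new default
        if len(feat) == 3:
            _, cls, dist = feat
        else:
            cls, dist = feat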
@@ -80,7 +80,9 @@ class VisionTransformerClsHead(ClsHead):
        obtain the feature of the last stage and forward it through the
        hidden layer if one exists.
        """
-        _, cls_token = feats[-1]
+        feat = feats[-1]  # Obtain feature of the last scale.
+        # For backward-compatibility with the previous ViT output
+        cls_token = feat[-1] if isinstance(feat, list) else feat
        if self.hidden_dim is None:
            return cls_token
        else:
@@ -183,7 +183,6 @@ class ClsBatchNormNeck(BaseModule):
            self,
            inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]:
        """The forward function."""
-        # Only apply batch norm to cls token, which is the second tensor in
-        # each item of the tuple
-        inputs = [[input_[0], self.bn(input_[1])] for input_ in inputs]
+        # Only apply batch norm to cls_token
+        inputs = [self.bn(input_) for input_ in inputs]
        return tuple(inputs)
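Since the backbone now emits the class token directly rather than a `[patch_token, cls_token]` pair, the neck can normalize each stage's tensor as-is. A shape-only sketch of the new contract (the BatchNorm1d configuration here is an assumption for illustration, not the neck's actual construction code):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(768, affine=False)   # assumed stand-in for self.bn
    inputs = (torch.randn(4, 768),)          # one (B, C) tensor per stage
    outputs = tuple(bn(x) for x in inputs)   # same shapes, normalized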
@@ -32,8 +32,6 @@ class NonLinearNeck(BaseModule):
            Defaults to False.
        with_avg_pool (bool): Whether to apply the global average pooling
            after backbone. Defaults to True.
-        vit_backbone (bool): The key to indicate whether the upstream backbone
-            is ViT. Defaults to False.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='SyncBN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
@@ -50,7 +48,6 @@ class NonLinearNeck(BaseModule):
                 with_last_bn_affine: bool = True,
                 with_last_bias: bool = False,
                 with_avg_pool: bool = True,
-                 vit_backbone: bool = False,
                 norm_cfg: dict = dict(type='SyncBN'),
                 init_cfg: Optional[Union[dict, List[dict]]] = [
                     dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
@@ -58,7 +55,6 @@ class NonLinearNeck(BaseModule):
        ) -> None:
        super(NonLinearNeck, self).__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
-        self.vit_backbone = vit_backbone
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
@@ -104,8 +100,6 @@ class NonLinearNeck(BaseModule):
        """
        assert len(x) == 1
        x = x[0]
-        if self.vit_backbone:
-            x = x[-1]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
@@ -140,15 +140,21 @@ class BEiTPretrainViT(BEiTViT):
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize the
            final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor, including both patch tokens
+              and class tokens, with shape (B, L, C).
+
+            It only works without an input mask. Defaults to ``"avg_featmap"``.
        with_cls_token (bool): Whether to concatenate the class token into the
            image tokens as transformer input. Defaults to True.
-        avg_token (bool): Whether or not to use the mean patch token for
-            classification. If True, the model will only take the average
-            of all patch tokens. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
-        output_cls_token (bool): Whether to output the cls_token. If set True,
-            ``with_cls_token`` must be True. Defaults to True.
        use_abs_pos_emb (bool): Whether or not to use absolute position
            embedding. Defaults to False.
        use_rel_pos_bias (bool): Whether or not to use relative position bias.
@@ -176,9 +182,8 @@ class BEiTPretrainViT(BEiTViT):
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
-                 avg_token: bool = False,
+                 out_type: str = 'avg_featmap',
                 frozen_stages: int = -1,
-                 output_cls_token: bool = True,
                 use_abs_pos_emb: bool = False,
                 use_rel_pos_bias: bool = False,
                 use_shared_rel_pos_bias: bool = True,
@@ -197,9 +202,9 @@ class BEiTPretrainViT(BEiTViT):
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
-            avg_token=avg_token,
+            out_type=out_type,
+            with_cls_token=True,
            frozen_stages=frozen_stages,
-            output_cls_token=output_cls_token,
            use_abs_pos_emb=use_abs_pos_emb,
            use_shared_rel_pos_bias=use_shared_rel_pos_bias,
            use_rel_pos_bias=use_rel_pos_bias,
@@ -217,8 +217,17 @@ class CAEPretrainViT(BEiTViT):
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize the
            final feature map. Defaults to True.
-        output_cls_token (bool): Whether to output the cls_token. If set True,
-            `with_cls_token` must be True. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor, including both patch tokens
+              and class tokens, with shape (B, L, C).
+
+            It only works without an input mask. Defaults to ``"avg_featmap"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float, optional): The init value of gamma in
@@ -242,10 +251,8 @@ class CAEPretrainViT(BEiTViT):
                 bias: bool = 'qv_bias',
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
-                 with_cls_token: bool = True,
-                 avg_token: bool = False,
+                 out_type: str = 'avg_featmap',
                 frozen_stages: int = -1,
-                 output_cls_token: bool = True,
                 use_abs_pos_emb: bool = True,
                 use_rel_pos_bias: bool = False,
                 use_shared_rel_pos_bias: bool = False,
@@ -270,10 +277,9 @@ class CAEPretrainViT(BEiTViT):
            bias=bias,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
-            with_cls_token=with_cls_token,
-            avg_token=avg_token,
+            out_type=out_type,
+            with_cls_token=True,
            frozen_stages=frozen_stages,
-            output_cls_token=output_cls_token,
            use_abs_pos_emb=use_abs_pos_emb,
            use_rel_pos_bias=use_rel_pos_bias,
            use_shared_rel_pos_bias=use_shared_rel_pos_bias,
@@ -33,8 +33,17 @@ class MAEViT(VisionTransformer):
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize the
            final feature map. Defaults to True.
-        output_cls_token (bool): Whether to output the cls_token. If set True,
-            `with_cls_token` must be True. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor, including both patch tokens
+              and class tokens, with shape (B, L, C).
+
+            It only works without an input mask. Defaults to ``"avg_featmap"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
@@ -55,7 +64,7 @@ class MAEViT(VisionTransformer):
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
-                 output_cls_token: bool = True,
+                 out_type: str = 'avg_featmap',
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
@@ -70,7 +79,8 @@ class MAEViT(VisionTransformer):
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
-            output_cls_token=output_cls_token,
+            out_type=out_type,
+            with_cls_token=True,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
@@ -178,8 +178,17 @@ class MaskFeatViT(VisionTransformer):
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize the
            final feature map. Defaults to True.
-        output_cls_token (bool): Whether to output the cls_token. If set True,
-            `with_cls_token` must be True. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor, including both patch tokens
+              and class tokens, with shape (B, L, C).
+
+            It only works without an input mask. Defaults to ``"avg_featmap"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
@@ -198,7 +207,7 @@ class MaskFeatViT(VisionTransformer):
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
-                 output_cls_token: bool = True,
+                 out_type: str = 'avg_featmap',
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
@@ -212,7 +221,8 @@ class MaskFeatViT(VisionTransformer):
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
-            output_cls_token=output_cls_token,
+            out_type=out_type,
+            with_cls_token=True,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
@@ -8,6 +8,14 @@ import torch.nn as nn
 from mmengine.model import BaseModule
 from mmengine.utils import digit_version

+# After PyTorch v1.10.0, calling torch.meshgrid without an ``indexing``
+# argument raises an extra warning. For more details, refer to
+# https://github.com/pytorch/pytorch/issues/50276
+if digit_version(torch.__version__) >= digit_version('1.10.0'):
+    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
+else:
+    torch_meshgrid = torch.meshgrid


 class ConditionalPositionEncoding(BaseModule):
     """The Conditional Position Encoding (CPE) module.
@@ -137,7 +145,7 @@ def build_2d_sincos_position_embedding(
     h, w = patches_resolution
     grid_w = torch.arange(w, dtype=torch.float32)
     grid_h = torch.arange(h, dtype=torch.float32)
-    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+    grid_w, grid_h = torch_meshgrid(grid_w, grid_h)
     assert embed_dims % 4 == 0, \
         'Embed dimension must be divisible by 4.'
     pos_dim = embed_dims // 4
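The shim pins the pre-1.10 semantics: `torch.meshgrid` historically used 'ij' (matrix) indexing, so passing `indexing='ij'` on newer releases produces identical grids while silencing the deprecation warning. A quick check:

    import torch

    a = torch.arange(3, dtype=torch.float32)
    b = torch.arange(2, dtype=torch.float32)
    ww, hh = torch.meshgrid(a, b, indexing='ij')
    assert ww.shape == (3, 2)  # first input varies along dim 0, as before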
@@ -82,18 +82,16 @@ class TestBEiT(TestCase):
     def test_forward(self):
         imgs = torch.randn(1, 3, 224, 224)

         # test with output_cls_token
         cfg = deepcopy(self.cfg)
-        cfg['output_cls_token'] = True
+        cfg['out_type'] = 'cls_token'
         model = BEiTViT(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 1)
-        patch_token, cls_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 768))
+        cls_token = outs[-1]
         self.assertEqual(cls_token.shape, (1, 768))

-        # test without output_cls_token
+        # test without output cls_token
         cfg = deepcopy(self.cfg)
         model = BEiTViT(**cfg)
         outs = model(imgs)
@@ -104,7 +102,7 @@ class TestBEiT(TestCase):

         # test without average
         cfg = deepcopy(self.cfg)
-        cfg['avg_token'] = False
+        cfg['out_type'] = 'featmap'
         model = BEiTViT(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
@@ -72,7 +72,6 @@ class TestCSPDarkNet(TestCase):
     def test_forward(self):
         imgs = torch.randn(3, 3, 224, 224)

-        # test without output_cls_token
         cfg = deepcopy(self.cfg)
         model = self.class_name(**cfg)
         outs = model(imgs)
@@ -62,37 +62,19 @@ class TestDeiT(TestCase):
     def test_forward(self):
         imgs = torch.randn(1, 3, 224, 224)

-        # test with_cls_token=False
-        cfg = deepcopy(self.cfg)
-        cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = True
-        with self.assertRaisesRegex(AssertionError, 'but got False'):
-            DistilledVisionTransformer(**cfg)
-
-        cfg = deepcopy(self.cfg)
-        cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = False
-        model = DistilledVisionTransformer(**cfg)
-        outs = model(imgs)
-        self.assertIsInstance(outs, tuple)
-        self.assertEqual(len(outs), 1)
-        patch_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 192, 14, 14))
-
-        # test with output_cls_token
+        # test with output cls_token
         cfg = deepcopy(self.cfg)
         model = DistilledVisionTransformer(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 1)
-        patch_token, cls_token, dist_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 192, 14, 14))
+        cls_token, dist_token = outs[-1]
         self.assertEqual(cls_token.shape, (1, 192))
         self.assertEqual(dist_token.shape, (1, 192))

-        # test without output_cls_token
+        # test without output cls_token
         cfg = deepcopy(self.cfg)
-        cfg['output_cls_token'] = False
+        cfg['out_type'] = 'featmap'
         model = DistilledVisionTransformer(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
@@ -108,8 +90,7 @@ class TestDeiT(TestCase):
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 3)
         for out in outs:
-            patch_token, cls_token, dist_token = out
-            self.assertEqual(patch_token.shape, (1, 192, 14, 14))
+            cls_token, dist_token = out
             self.assertEqual(cls_token.shape, (1, 192))
             self.assertEqual(dist_token.shape, (1, 192))
@@ -118,14 +99,13 @@ class TestDeiT(TestCase):
         imgs2 = torch.randn(1, 3, 256, 256)
         imgs3 = torch.randn(1, 3, 256, 309)
         cfg = deepcopy(self.cfg)
+        cfg['out_type'] = 'featmap'
         model = DistilledVisionTransformer(**cfg)
         for imgs in [imgs1, imgs2, imgs3]:
             outs = model(imgs)
             self.assertIsInstance(outs, tuple)
             self.assertEqual(len(outs), 1)
-            patch_token, cls_token, dist_token = outs[-1]
+            featmap = outs[-1]
             expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                  math.ceil(imgs.shape[3] / 16))
-            self.assertEqual(patch_token.shape, (1, 192, *expect_feat_shape))
-            self.assertEqual(cls_token.shape, (1, 192))
-            self.assertEqual(dist_token.shape, (1, 192))
+            self.assertEqual(featmap.shape, (1, 192, *expect_feat_shape))
@@ -116,13 +116,13 @@ class TestDeiT3(TestCase):
         # test with_cls_token=False
         cfg = deepcopy(self.cfg)
         cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = True
-        with self.assertRaisesRegex(AssertionError, 'but got False'):
+        cfg['out_type'] = 'cls_token'
+        with self.assertRaisesRegex(ValueError, 'must be True'):
             DeiT3(**cfg)

         cfg = deepcopy(self.cfg)
         cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = False
+        cfg['out_type'] = 'featmap'
         model = DeiT3(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
@@ -130,26 +130,15 @@ class TestDeiT3(TestCase):
         patch_token = outs[-1]
         self.assertEqual(patch_token.shape, (1, 768, 14, 14))

-        # test with output_cls_token
+        # test with output cls_token
         cfg = deepcopy(self.cfg)
         model = DeiT3(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 1)
-        patch_token, cls_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
+        cls_token = outs[-1]
         self.assertEqual(cls_token.shape, (1, 768))

-        # test without output_cls_token
-        cfg = deepcopy(self.cfg)
-        cfg['output_cls_token'] = False
-        model = DeiT3(**cfg)
-        outs = model(imgs)
-        self.assertIsInstance(outs, tuple)
-        self.assertEqual(len(outs), 1)
-        patch_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
-
         # Test forward with multi out indices
         cfg = deepcopy(self.cfg)
         cfg['out_indices'] = [-3, -2, -1]
@@ -158,8 +147,7 @@ class TestDeiT3(TestCase):
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 3)
         for out in outs:
-            patch_token, cls_token = out
-            self.assertEqual(patch_token.shape, (1, 768, 14, 14))
+            cls_token = out
             self.assertEqual(cls_token.shape, (1, 768))

         # Test forward with dynamic input size
@@ -167,13 +155,13 @@ class TestDeiT3(TestCase):
         imgs2 = torch.randn(1, 3, 256, 256)
         imgs3 = torch.randn(1, 3, 256, 309)
         cfg = deepcopy(self.cfg)
+        cfg['out_type'] = 'featmap'
         model = DeiT3(**cfg)
         for imgs in [imgs1, imgs2, imgs3]:
             outs = model(imgs)
             self.assertIsInstance(outs, tuple)
             self.assertEqual(len(outs), 1)
-            patch_token, cls_token = outs[-1]
+            featmap = outs[-1]
             expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                  math.ceil(imgs.shape[3] / 16))
-            self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape))
-            self.assertEqual(cls_token.shape, (1, 768))
+            self.assertEqual(featmap.shape, (1, 768, *expect_feat_shape))
@@ -49,16 +49,6 @@ class TestRevVisionTransformer(TestCase):
         self.assertEqual(layer.attn.num_heads, 16)
         self.assertEqual(layer.ffn.feedforward_channels, 1024)

-        # Test out_indices
-        # TODO: to be implemented, current only support last layer
-        cfg = deepcopy(self.cfg)
-        cfg['out_indices'] = {1: 1}
-        with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
-            RevVisionTransformer(**cfg)
-        cfg['out_indices'] = [13]
-        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
-            RevVisionTransformer(**cfg)
-
         # Test model structure
         cfg = deepcopy(self.cfg)
         model = RevVisionTransformer(**cfg)
@@ -108,8 +98,8 @@ class TestRevVisionTransformer(TestCase):
         cfg['img_size'] = 384
         model = RevVisionTransformer(**cfg)
         load_checkpoint(model, checkpoint, strict=True)
-        resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed,
-                                                  model.pos_embed)
+        resized_pos_embed = timm_resize_pos_embed(
+            pretrain_pos_embed, model.pos_embed, num_tokens=0)
         self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))

         os.remove(checkpoint)
@@ -119,7 +109,7 @@ class TestRevVisionTransformer(TestCase):

         cfg = deepcopy(self.cfg)
         cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = False
+        cfg['out_type'] = 'avg_featmap'
         model = RevVisionTransformer(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
@@ -137,5 +127,5 @@ class TestRevVisionTransformer(TestCase):
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 1)
-        avg_token = outs[-1]
-        self.assertEqual(avg_token.shape, (1, 768 * 2))
+        avg_featmap = outs[-1]
+        self.assertEqual(avg_featmap.shape, (1, 768 * 2))
@@ -94,13 +94,13 @@ class TestT2TViT(TestCase):
         # test with_cls_token=False
         cfg = deepcopy(self.cfg)
         cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = True
-        with self.assertRaisesRegex(AssertionError, 'but got False'):
+        cfg['out_type'] = 'cls_token'
+        with self.assertRaisesRegex(ValueError, 'must be True'):
             T2T_ViT(**cfg)

         cfg = deepcopy(self.cfg)
         cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = False
+        cfg['out_type'] = 'featmap'
         model = T2T_ViT(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
@@ -108,26 +108,15 @@ class TestT2TViT(TestCase):
         patch_token = outs[-1]
         self.assertEqual(patch_token.shape, (1, 384, 14, 14))

-        # test with output_cls_token
+        # test with output cls_token
         cfg = deepcopy(self.cfg)
         model = T2T_ViT(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 1)
-        patch_token, cls_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 384, 14, 14))
+        cls_token = outs[-1]
         self.assertEqual(cls_token.shape, (1, 384))

-        # test without output_cls_token
-        cfg = deepcopy(self.cfg)
-        cfg['output_cls_token'] = False
-        model = T2T_ViT(**cfg)
-        outs = model(imgs)
-        self.assertIsInstance(outs, tuple)
-        self.assertEqual(len(outs), 1)
-        patch_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 384, 14, 14))
-
         # Test forward with multi out indices
         cfg = deepcopy(self.cfg)
         cfg['out_indices'] = [-3, -2, -1]
@@ -136,22 +125,20 @@ class TestT2TViT(TestCase):
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 3)
         for out in outs:
-            patch_token, cls_token = out
-            self.assertEqual(patch_token.shape, (1, 384, 14, 14))
-            self.assertEqual(cls_token.shape, (1, 384))
+            self.assertEqual(out.shape, (1, 384))

         # Test forward with dynamic input size
         imgs1 = torch.randn(1, 3, 224, 224)
         imgs2 = torch.randn(1, 3, 256, 256)
         imgs3 = torch.randn(1, 3, 256, 309)
         cfg = deepcopy(self.cfg)
+        cfg['out_type'] = 'featmap'
         model = T2T_ViT(**cfg)
         for imgs in [imgs1, imgs2, imgs3]:
             outs = model(imgs)
             self.assertIsInstance(outs, tuple)
             self.assertEqual(len(outs), 1)
-            patch_token, cls_token = outs[-1]
+            patch_token = outs[-1]
             expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                  math.ceil(imgs.shape[3] / 16))
             self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape))
-            self.assertEqual(cls_token.shape, (1, 384))
@@ -126,13 +126,13 @@ class TestVisionTransformer(TestCase):
         # test with_cls_token=False
         cfg = deepcopy(self.cfg)
         cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = True
-        with self.assertRaisesRegex(AssertionError, 'but got False'):
+        cfg['out_type'] = 'cls_token'
+        with self.assertRaisesRegex(ValueError, 'must be True'):
             VisionTransformer(**cfg)

         cfg = deepcopy(self.cfg)
         cfg['with_cls_token'] = False
-        cfg['output_cls_token'] = False
+        cfg['out_type'] = 'featmap'
         model = VisionTransformer(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
@@ -140,26 +140,15 @@ class TestVisionTransformer(TestCase):
         patch_token = outs[-1]
         self.assertEqual(patch_token.shape, (1, 768, 14, 14))

-        # test with output_cls_token
+        # test with output cls_token
         cfg = deepcopy(self.cfg)
         model = VisionTransformer(**cfg)
         outs = model(imgs)
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 1)
-        patch_token, cls_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
+        cls_token = outs[-1]
         self.assertEqual(cls_token.shape, (1, 768))

-        # test without output_cls_token
-        cfg = deepcopy(self.cfg)
-        cfg['output_cls_token'] = False
-        model = VisionTransformer(**cfg)
-        outs = model(imgs)
-        self.assertIsInstance(outs, tuple)
-        self.assertEqual(len(outs), 1)
-        patch_token = outs[-1]
-        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
-
         # Test forward with multi out indices
         cfg = deepcopy(self.cfg)
         cfg['out_indices'] = [-3, -2, -1]
@@ -168,22 +157,20 @@ class TestVisionTransformer(TestCase):
         self.assertIsInstance(outs, tuple)
         self.assertEqual(len(outs), 3)
         for out in outs:
-            patch_token, cls_token = out
-            self.assertEqual(patch_token.shape, (1, 768, 14, 14))
-            self.assertEqual(cls_token.shape, (1, 768))
+            self.assertEqual(out.shape, (1, 768))

         # Test forward with dynamic input size
         imgs1 = torch.randn(1, 3, 224, 224)
         imgs2 = torch.randn(1, 3, 256, 256)
         imgs3 = torch.randn(1, 3, 256, 309)
         cfg = deepcopy(self.cfg)
+        cfg['out_type'] = 'featmap'
         model = VisionTransformer(**cfg)
         for imgs in [imgs1, imgs2, imgs3]:
             outs = model(imgs)
             self.assertIsInstance(outs, tuple)
             self.assertEqual(len(outs), 1)
-            patch_token, cls_token = outs[-1]
+            patch_token = outs[-1]
             expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                  math.ceil(imgs.shape[3] / 16))
             self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape))
-            self.assertEqual(cls_token.shape, (1, 768))
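Taken together, the test updates above are the whole migration story for downstream code: with the default `out_type='cls_token'`, each selected stage is now a single tensor rather than a `[patch_token, cls_token]` pair. A hedged before/after sketch (`model` and `imgs` stand in for any of the test setups above):

    outs = model(imgs)  # tuple, one entry per out_index

    # before this refactor:
    #     patch_token, cls_token = outs[-1]
    # after it:
    cls_token = outs[-1]

    # the spatial view is still available, but must be requested explicitly:
    #     model = VisionTransformer(arch='base', out_type='featmap')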
@@ -35,9 +35,7 @@ class TestBEiT(TestCase):

         # test without mask
         fake_outputs = beit_backbone(fake_inputs, None)
-        assert len(fake_outputs[0]) == 2
-        assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
-        assert fake_outputs[0][1].shape == torch.Size([2, 768])
+        assert fake_outputs[0].shape == torch.Size([2, 768])

     @pytest.mark.skipif(
         platform.system() == 'Windows', reason='Windows mem limit')
@@ -111,10 +109,9 @@ class TestBEiT(TestCase):
            drop_path_rate=0.,
            norm_cfg=dict(type='LN', eps=1e-6),
            final_norm=True,
+            out_type='featmap',
            with_cls_token=True,
-            avg_token=False,
            frozen_stages=-1,
-            output_cls_token=False,
            use_abs_pos_emb=True,
            use_rel_pos_bias=False,
            use_shared_rel_pos_bias=False,
@@ -25,9 +25,7 @@ def test_cae_vit():

     # test without mask
     fake_outputs = cae_backbone(fake_inputs, None)
-    assert len(fake_outputs[0]) == 2
-    assert fake_outputs[0][0].shape == torch.Size([1, 192, 14, 14])
-    assert fake_outputs[0][1].shape == torch.Size([1, 192])
+    assert fake_outputs[0].shape == torch.Size([1, 192])


 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
@@ -21,9 +21,7 @@ def test_mae_vit():

     # test without mask
     fake_outputs = mae_backbone(fake_inputs, None)
-    assert len(fake_outputs[0]) == 2
-    assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
-    assert fake_outputs[0][1].shape == torch.Size([2, 768])
+    assert fake_outputs[0].shape == torch.Size([2, 768])


 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
@@ -22,9 +22,7 @@ def test_maskfeat_vit():

     # test without mask
     fake_outputs = maskfeat_backbone(fake_inputs, None)
-    assert len(fake_outputs[0]) == 2
-    assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
-    assert fake_outputs[0][1].shape == torch.Size([2, 768])
+    assert fake_outputs[0].shape == torch.Size([2, 768])


 @pytest.mark.skipif(
@@ -24,9 +24,7 @@ def test_milan_vit():

     # test without mask
     fake_outputs = milan_backbone(fake_inputs, None)
-    assert len(fake_outputs[0]) == 2
-    assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14])
-    assert fake_outputs[0][1].shape == torch.Size([2, 768])
+    assert fake_outputs[0].shape == torch.Size([2, 768])


 @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
@@ -29,7 +29,6 @@ class TestMoCoV3(TestCase):
            with_last_bn_affine=False,
            with_last_bias=False,
            with_avg_pool=False,
-            vit_backbone=True,
            norm_cfg=dict(type='BN1d'))
        head = dict(
            type='MoCoV3Head',
@@ -89,5 +88,4 @@ class TestMoCoV3(TestCase):

        # test extract
        fake_feats = alg(fake_inputs['inputs'][0], mode='tensor')
-        self.assertEqual(fake_feats[0][0].size(), torch.Size([2, 384, 14, 14]))
-        self.assertEqual(fake_feats[0][1].size(), torch.Size([2, 384]))
+        self.assertEqual(fake_feats[0].size(), torch.Size([2, 384]))
@@ -50,10 +50,9 @@ class TestVQKD(TestCase):
            drop_path_rate=0.,
            norm_cfg=dict(type='LN', eps=1e-6),
            final_norm=True,
+            out_type='featmap',
            with_cls_token=True,
-            avg_token=False,
            frozen_stages=-1,
-            output_cls_token=False,
            use_abs_pos_emb=True,
            use_rel_pos_bias=False,
            use_shared_rel_pos_bias=False,