From dbf3df21a31ee5404184b3f1ba0588b6b6aaa8f4 Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Thu, 9 Mar 2023 11:02:58 +0800 Subject: [PATCH] [Refactor] Use `out_type` to specify ViT-like backbone output. (#1408) * [Refactor] Use to specify ViT-like backbone output. * Fix ClsBatchNormNeck * Update mmpretrain/models/necks/mae_neck.py --------- Co-authored-by: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com> --- .../datasets/imagenet_bs128_revvit_224.py | 2 +- configs/_base_/models/eva/eva-g.py | 3 +- configs/_base_/models/eva/eva-l.py | 3 +- configs/_base_/models/revvit/revvit-base.py | 4 +- configs/_base_/models/revvit/revvit-small.py | 4 +- .../beit-base-p16_8xb128-coslr-100e_in1k.py | 3 +- .../benchmarks/beit-base-p16_8xb64_in1k.py | 3 +- ...it-base-p16_8xb256-amp-coslr-1600e_in1k.py | 3 +- ...eit-base-p16_8xb256-amp-coslr-300e_in1k.py | 3 +- .../beit-base-p16_8xb128-coslr-100e_in1k.py | 3 +- .../benchmarks/beit-base-p16_8xb64_in1k.py | 3 +- .../beit-base-p16_8xb128-coslr-100e_in1k.py | 3 +- .../vit-base-p16_8xb128-coslr-100e_in1k.py | 3 +- ...base-p16_8xb2048-linear-coslr-100e_in1k.py | 2 +- configs/eva/eva-g-p14_headless.py | 3 +- configs/eva/eva-g-p16_headless.py | 3 +- configs/eva/eva-l-p14_headless.py | 3 +- .../vit-base-p16_8xb128-coslr-100e_in1k.py | 3 +- ...-base-p16_8xb2048-linear-coslr-90e_in1k.py | 2 +- ...vit-huge-p14_32xb8-coslr-50e_in1k-448px.py | 3 +- .../vit-huge-p14_8xb128-coslr-50e_in1k.py | 3 +- .../vit-large-p16_8xb128-coslr-50e_in1k.py | 3 +- ...large-p16_8xb2048-linear-coslr-90e_in1k.py | 2 +- .../vit-base-p16_8xb256-coslr-100e_in1k.py | 3 +- .../vit-base-p16_8xb128-coslr-100e_in1k.py | 3 +- ...base-p16_8xb2048-linear-coslr-100e_in1k.py | 2 +- ...ov3_resnet50_8xb512-amp-coslr-100e_in1k.py | 3 +- ...ov3_resnet50_8xb512-amp-coslr-300e_in1k.py | 3 +- ...ov3_resnet50_8xb512-amp-coslr-800e_in1k.py | 3 +- ...it-base-p16_16xb256-amp-coslr-300e_in1k.py | 3 +- ...it-large-p16_64xb64-amp-coslr-300e_in1k.py | 3 +- ...t-small-p16_16xb256-amp-coslr-300e_in1k.py | 3 +- configs/revvit/metafile.yml | 4 +- configs/tsne/vit-base-p16_imagenet.py | 3 +- mmpretrain/models/backbones/beit.py | 114 ++++++------ mmpretrain/models/backbones/deit.py | 53 +++--- mmpretrain/models/backbones/deit3.py | 123 +++++++------ mmpretrain/models/backbones/mixmim.py | 2 +- mmpretrain/models/backbones/revvit.py | 169 +++++++----------- mmpretrain/models/backbones/t2t_vit.py | 108 +++++------ .../models/backbones/vision_transformer.py | 134 +++++++------- mmpretrain/models/heads/deit_head.py | 7 +- .../models/heads/vision_transformer_head.py | 4 +- mmpretrain/models/necks/mae_neck.py | 5 +- mmpretrain/models/necks/nonlinear_neck.py | 6 - mmpretrain/models/selfsup/beit.py | 23 ++- mmpretrain/models/selfsup/cae.py | 22 ++- mmpretrain/models/selfsup/mae.py | 18 +- mmpretrain/models/selfsup/maskfeat.py | 18 +- mmpretrain/models/utils/position_encoding.py | 10 +- tests/test_models/test_backbones/test_beit.py | 10 +- .../test_models/test_backbones/test_cspnet.py | 1 - tests/test_models/test_backbones/test_deit.py | 36 +--- .../test_models/test_backbones/test_deit3.py | 30 +--- .../test_models/test_backbones/test_revvit.py | 20 +-- .../test_backbones/test_t2t_vit.py | 29 +-- .../test_backbones/test_vision_transformer.py | 29 +-- tests/test_models/test_selfsup/test_beit.py | 7 +- tests/test_models/test_selfsup/test_cae.py | 4 +- tests/test_models/test_selfsup/test_mae.py | 4 +- .../test_models/test_selfsup/test_maskfeat.py | 4 +- tests/test_models/test_selfsup/test_milan.py | 4 +- 
tests/test_models/test_selfsup/test_mocov3.py | 4 +- .../test_selfsup/test_target_generators.py | 3 +- 64 files changed, 497 insertions(+), 604 deletions(-) diff --git a/configs/_base_/datasets/imagenet_bs128_revvit_224.py b/configs/_base_/datasets/imagenet_bs128_revvit_224.py index 821cfaf4..1b5ad7c8 100644 --- a/configs/_base_/datasets/imagenet_bs128_revvit_224.py +++ b/configs/_base_/datasets/imagenet_bs128_revvit_224.py @@ -72,7 +72,7 @@ val_dataloader = dict( dataset=dict( type=dataset_type, data_root='data/imagenet', - # ann_file='meta/val.txt', + ann_file='meta/val.txt', data_prefix='val', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), diff --git a/configs/_base_/models/eva/eva-g.py b/configs/_base_/models/eva/eva-g.py index 91a3a43d..17bc84ad 100644 --- a/configs/_base_/models/eva/eva-g.py +++ b/configs/_base_/models/eva/eva-g.py @@ -5,9 +5,8 @@ model = dict( arch='eva-g', img_size=224, patch_size=14, - avg_token=True, layer_scale_init_value=0.0, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/configs/_base_/models/eva/eva-l.py b/configs/_base_/models/eva/eva-l.py index 67405f60..9b08e4b1 100644 --- a/configs/_base_/models/eva/eva-l.py +++ b/configs/_base_/models/eva/eva-l.py @@ -5,9 +5,8 @@ model = dict( arch='l', img_size=224, patch_size=14, - avg_token=True, layer_scale_init_value=0.0, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/configs/_base_/models/revvit/revvit-base.py b/configs/_base_/models/revvit/revvit-base.py index 354498ed..85b7af42 100644 --- a/configs/_base_/models/revvit/revvit-base.py +++ b/configs/_base_/models/revvit/revvit-base.py @@ -6,9 +6,7 @@ model = dict( arch='deit-base', img_size=224, patch_size=16, - output_cls_token=False, - avg_token=True, - with_cls_token=False, + out_type='avg_featmap', ), neck=None, head=dict( diff --git a/configs/_base_/models/revvit/revvit-small.py b/configs/_base_/models/revvit/revvit-small.py index 6d43781a..dd1a0b26 100644 --- a/configs/_base_/models/revvit/revvit-small.py +++ b/configs/_base_/models/revvit/revvit-small.py @@ -6,9 +6,7 @@ model = dict( arch='deit-small', img_size=224, patch_size=16, - output_cls_token=False, - avg_token=True, - with_cls_token=False, + out_type='avg_featmap', ), neck=None, head=dict( diff --git a/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py b/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py index 3e6b5d75..e7c54379 100644 --- a/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py +++ b/configs/beit/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py @@ -21,8 +21,7 @@ model = dict( img_size=224, patch_size=16, drop_path_rate=0.1, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=False, use_rel_pos_bias=True, use_shared_rel_pos_bias=False), diff --git a/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py b/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py index 7fbc82e2..8380b69a 100644 --- a/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py +++ b/configs/beit/benchmarks/beit-base-p16_8xb64_in1k.py @@ -20,8 +20,7 @@ model = dict( arch='base', img_size=224, patch_size=16, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=False, use_rel_pos_bias=True, use_shared_rel_pos_bias=False, diff --git a/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py 
b/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py index 389c715e..2ee1fa72 100644 --- a/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py +++ b/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-1600e_in1k.py @@ -14,10 +14,9 @@ vqkd_encoder = dict( drop_path_rate=0., norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + out_type='featmap', with_cls_token=True, - avg_token=False, frozen_stages=-1, - output_cls_token=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py b/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py index fd5dfb57..fcfe6393 100644 --- a/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py +++ b/configs/beitv2/beitv2_beit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -14,10 +14,9 @@ vqkd_encoder = dict( drop_path_rate=0., norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + out_type='featmap', with_cls_token=True, - avg_token=False, frozen_stages=-1, - output_cls_token=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py b/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py index e58b7d65..ee8fabec 100644 --- a/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py +++ b/configs/beitv2/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py @@ -14,8 +14,7 @@ model = dict( patch_size=16, # 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs. drop_path_rate=0.1, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=False, use_rel_pos_bias=True, use_shared_rel_pos_bias=False), diff --git a/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py b/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py index 2b937019..17ed4ff3 100644 --- a/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py +++ b/configs/beitv2/benchmarks/beit-base-p16_8xb64_in1k.py @@ -11,8 +11,7 @@ model = dict( arch='base', img_size=224, patch_size=16, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=False, use_rel_pos_bias=True, use_shared_rel_pos_bias=False, diff --git a/configs/cae/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py b/configs/cae/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py index f3114aab..67f2d373 100644 --- a/configs/cae/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py +++ b/configs/cae/benchmarks/beit-base-p16_8xb128-coslr-100e_in1k.py @@ -67,11 +67,10 @@ model = dict( arch='base', img_size=224, patch_size=16, - avg_token=True, # use average token for cls head final_norm=False, # do not use final norm drop_path_rate=0.1, layer_scale_init_value=0.1, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=True, use_rel_pos_bias=True, use_shared_rel_pos_bias=False, diff --git a/configs/eva/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py b/configs/eva/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py index 3210aace..a90185de 100644 --- a/configs/eva/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py +++ b/configs/eva/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py @@ -56,8 +56,7 @@ model = dict( img_size=224, patch_size=16, drop_path_rate=0.1, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', final_norm=False, init_cfg=dict(type='Pretrained', checkpoint='')), neck=None, diff --git a/configs/eva/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py 
b/configs/eva/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py index e1661e31..20db6e56 100644 --- a/configs/eva/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py +++ b/configs/eva/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py @@ -17,7 +17,7 @@ model = dict( img_size=224, patch_size=16, frozen_stages=12, - avg_token=False, + out_type='cls_token', final_norm=True, init_cfg=dict(type='Pretrained', checkpoint='')), neck=dict(type='ClsBatchNormNeck', input_features=768), diff --git a/configs/eva/eva-g-p14_headless.py b/configs/eva/eva-g-p14_headless.py index 8d408507..b278acea 100644 --- a/configs/eva/eva-g-p14_headless.py +++ b/configs/eva/eva-g-p14_headless.py @@ -5,9 +5,8 @@ model = dict( arch='eva-g', img_size=224, patch_size=14, - avg_token=True, layer_scale_init_value=0.0, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/configs/eva/eva-g-p16_headless.py b/configs/eva/eva-g-p16_headless.py index b326ce29..ca5de186 100644 --- a/configs/eva/eva-g-p16_headless.py +++ b/configs/eva/eva-g-p16_headless.py @@ -5,9 +5,8 @@ model = dict( arch='eva-g', img_size=224, patch_size=16, - avg_token=True, layer_scale_init_value=0.0, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/configs/eva/eva-l-p14_headless.py b/configs/eva/eva-l-p14_headless.py index 6d3c95ba..89a4ce10 100644 --- a/configs/eva/eva-l-p14_headless.py +++ b/configs/eva/eva-l-p14_headless.py @@ -5,9 +5,8 @@ model = dict( arch='l', img_size=224, patch_size=14, - avg_token=True, layer_scale_init_value=0.0, - output_cls_token=False, + out_type='avg_featmap', use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py b/configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py index 8ae1e444..51181562 100644 --- a/configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py +++ b/configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py @@ -55,8 +55,7 @@ model = dict( img_size=224, patch_size=16, drop_path_rate=0.1, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', final_norm=False, init_cfg=dict(type='Pretrained', checkpoint='')), neck=None, diff --git a/configs/mae/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py b/configs/mae/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py index 90f9a596..39b14e57 100644 --- a/configs/mae/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py +++ b/configs/mae/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py @@ -18,7 +18,7 @@ model = dict( img_size=224, patch_size=16, frozen_stages=12, - avg_token=False, + out_type='cls_token', final_norm=True, init_cfg=dict(type='Pretrained', checkpoint='')), neck=dict(type='ClsBatchNormNeck', input_features=768), diff --git a/configs/mae/benchmarks/vit-huge-p14_32xb8-coslr-50e_in1k-448px.py b/configs/mae/benchmarks/vit-huge-p14_32xb8-coslr-50e_in1k-448px.py index a61e2810..9ac9dbb4 100644 --- a/configs/mae/benchmarks/vit-huge-p14_32xb8-coslr-50e_in1k-448px.py +++ b/configs/mae/benchmarks/vit-huge-p14_32xb8-coslr-50e_in1k-448px.py @@ -57,8 +57,7 @@ model = dict( img_size=448, patch_size=14, drop_path_rate=0.3, # set to 0.3 - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', final_norm=False, init_cfg=dict(type='Pretrained', checkpoint='')), neck=None, diff --git 
a/configs/mae/benchmarks/vit-huge-p14_8xb128-coslr-50e_in1k.py b/configs/mae/benchmarks/vit-huge-p14_8xb128-coslr-50e_in1k.py index d643dfbe..71ba3696 100644 --- a/configs/mae/benchmarks/vit-huge-p14_8xb128-coslr-50e_in1k.py +++ b/configs/mae/benchmarks/vit-huge-p14_8xb128-coslr-50e_in1k.py @@ -56,8 +56,7 @@ model = dict( img_size=224, patch_size=14, drop_path_rate=0.3, # set to 0.3 - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', final_norm=False, init_cfg=dict(type='Pretrained', checkpoint='')), neck=None, diff --git a/configs/mae/benchmarks/vit-large-p16_8xb128-coslr-50e_in1k.py b/configs/mae/benchmarks/vit-large-p16_8xb128-coslr-50e_in1k.py index 77398dff..9ad6909a 100644 --- a/configs/mae/benchmarks/vit-large-p16_8xb128-coslr-50e_in1k.py +++ b/configs/mae/benchmarks/vit-large-p16_8xb128-coslr-50e_in1k.py @@ -56,8 +56,7 @@ model = dict( img_size=224, patch_size=16, drop_path_rate=0.2, # set to 0.2 - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', final_norm=False, init_cfg=dict(type='Pretrained', checkpoint='')), neck=None, diff --git a/configs/mae/benchmarks/vit-large-p16_8xb2048-linear-coslr-90e_in1k.py b/configs/mae/benchmarks/vit-large-p16_8xb2048-linear-coslr-90e_in1k.py index 145f3dcf..8b363fd5 100644 --- a/configs/mae/benchmarks/vit-large-p16_8xb2048-linear-coslr-90e_in1k.py +++ b/configs/mae/benchmarks/vit-large-p16_8xb2048-linear-coslr-90e_in1k.py @@ -18,7 +18,7 @@ model = dict( img_size=224, patch_size=16, frozen_stages=24, - avg_token=False, + out_type='cls_token', final_norm=True, init_cfg=dict(type='Pretrained', checkpoint='')), neck=dict(type='ClsBatchNormNeck', input_features=1024), diff --git a/configs/maskfeat/benchmarks/vit-base-p16_8xb256-coslr-100e_in1k.py b/configs/maskfeat/benchmarks/vit-base-p16_8xb256-coslr-100e_in1k.py index 9d781b66..29722b34 100644 --- a/configs/maskfeat/benchmarks/vit-base-p16_8xb256-coslr-100e_in1k.py +++ b/configs/maskfeat/benchmarks/vit-base-p16_8xb256-coslr-100e_in1k.py @@ -54,8 +54,7 @@ model = dict( img_size=224, patch_size=16, drop_path_rate=0.1, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', final_norm=False, init_cfg=dict(type='Pretrained', checkpoint='')), neck=None, diff --git a/configs/milan/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py b/configs/milan/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py index 3210aace..a90185de 100644 --- a/configs/milan/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py +++ b/configs/milan/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py @@ -56,8 +56,7 @@ model = dict( img_size=224, patch_size=16, drop_path_rate=0.1, - avg_token=True, - output_cls_token=False, + out_type='avg_featmap', final_norm=False, init_cfg=dict(type='Pretrained', checkpoint='')), neck=None, diff --git a/configs/milan/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py b/configs/milan/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py index e1661e31..20db6e56 100644 --- a/configs/milan/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py +++ b/configs/milan/benchmarks/vit-base-p16_8xb2048-linear-coslr-100e_in1k.py @@ -17,7 +17,7 @@ model = dict( img_size=224, patch_size=16, frozen_stages=12, - avg_token=False, + out_type='cls_token', final_norm=True, init_cfg=dict(type='Pretrained', checkpoint='')), neck=dict(type='ClsBatchNormNeck', input_features=768), diff --git a/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-100e_in1k.py b/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-100e_in1k.py index da8f5d50..419cda17 100644 --- 
a/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-100e_in1k.py +++ b/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-100e_in1k.py @@ -19,8 +19,7 @@ model = dict( with_last_bn=True, with_last_bn_affine=False, with_last_bias=False, - with_avg_pool=True, - vit_backbone=False), + with_avg_pool=True), head=dict( type='MoCoV3Head', predictor=dict( diff --git a/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-300e_in1k.py b/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-300e_in1k.py index 0f9a76d9..8f7e5970 100644 --- a/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-300e_in1k.py +++ b/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-300e_in1k.py @@ -19,8 +19,7 @@ model = dict( with_last_bn=True, with_last_bn_affine=False, with_last_bias=False, - with_avg_pool=True, - vit_backbone=False), + with_avg_pool=True), head=dict( type='MoCoV3Head', predictor=dict( diff --git a/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-800e_in1k.py b/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-800e_in1k.py index fd151b4b..ac59c920 100644 --- a/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-800e_in1k.py +++ b/configs/mocov3/mocov3_resnet50_8xb512-amp-coslr-800e_in1k.py @@ -19,8 +19,7 @@ model = dict( with_last_bn=True, with_last_bn_affine=False, with_last_bias=False, - with_avg_pool=True, - vit_backbone=False), + with_avg_pool=True), head=dict( type='MoCoV3Head', predictor=dict( diff --git a/configs/mocov3/mocov3_vit-base-p16_16xb256-amp-coslr-300e_in1k.py b/configs/mocov3/mocov3_vit-base-p16_16xb256-amp-coslr-300e_in1k.py index 19f01d7d..6b18fda7 100644 --- a/configs/mocov3/mocov3_vit-base-p16_16xb256-amp-coslr-300e_in1k.py +++ b/configs/mocov3/mocov3_vit-base-p16_16xb256-amp-coslr-300e_in1k.py @@ -99,8 +99,7 @@ model = dict( with_last_bn=True, with_last_bn_affine=False, with_last_bias=False, - with_avg_pool=False, - vit_backbone=True), + with_avg_pool=False), head=dict( type='MoCoV3Head', predictor=dict( diff --git a/configs/mocov3/mocov3_vit-large-p16_64xb64-amp-coslr-300e_in1k.py b/configs/mocov3/mocov3_vit-large-p16_64xb64-amp-coslr-300e_in1k.py index 32012044..ae31c6d8 100644 --- a/configs/mocov3/mocov3_vit-large-p16_64xb64-amp-coslr-300e_in1k.py +++ b/configs/mocov3/mocov3_vit-large-p16_64xb64-amp-coslr-300e_in1k.py @@ -99,8 +99,7 @@ model = dict( with_last_bn=True, with_last_bn_affine=False, with_last_bias=False, - with_avg_pool=False, - vit_backbone=True), + with_avg_pool=False), head=dict( type='MoCoV3Head', predictor=dict( diff --git a/configs/mocov3/mocov3_vit-small-p16_16xb256-amp-coslr-300e_in1k.py b/configs/mocov3/mocov3_vit-small-p16_16xb256-amp-coslr-300e_in1k.py index 3fc3ad3a..0d26eec7 100644 --- a/configs/mocov3/mocov3_vit-small-p16_16xb256-amp-coslr-300e_in1k.py +++ b/configs/mocov3/mocov3_vit-small-p16_16xb256-amp-coslr-300e_in1k.py @@ -99,8 +99,7 @@ model = dict( with_last_bn=True, with_last_bn_affine=False, with_last_bias=False, - with_avg_pool=False, - vit_backbone=True), + with_avg_pool=False), head=dict( type='MoCoV3Head', predictor=dict( diff --git a/configs/revvit/metafile.yml b/configs/revvit/metafile.yml index 6b5b5818..7c3eb4d1 100644 --- a/configs/revvit/metafile.yml +++ b/configs/revvit/metafile.yml @@ -25,7 +25,7 @@ Models: Top 1 Accuracy: 79.87 Top 5 Accuracy: 94.90 Task: Image Classification - Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-base_3rdparty_in1k_20221213-87a7b0a5.pth + Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-small_3rdparty_in1k_20221213-a3a34f5c.pth Config: configs/revvit/revvit-small_8xb256_in1k.py 
Converted From: Weights: https://dl.fbaipublicfiles.com/pyslowfast/rev/REV_VIT_S.pyth @@ -41,7 +41,7 @@ Models: Top 1 Accuracy: 81.81 Top 5 Accuracy: 95.56 Task: Image Classification - Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-small_3rdparty_in1k_20221213-a3a34f5c.pth + Weights: https://download.openmmlab.com/mmclassification/v0/revvit/revvit-base_3rdparty_in1k_20221213-87a7b0a5.pth Config: configs/revvit/revvit-base_8xb256_in1k.py Converted From: Weights: https://dl.fbaipublicfiles.com/pyslowfast/rev/REV_VIT_B.pyth diff --git a/configs/tsne/vit-base-p16_imagenet.py b/configs/tsne/vit-base-p16_imagenet.py index 38d542a5..609ddf3d 100644 --- a/configs/tsne/vit-base-p16_imagenet.py +++ b/configs/tsne/vit-base-p16_imagenet.py @@ -9,8 +9,7 @@ model = dict( patch_size=16, out_indices=-1, drop_path_rate=0.1, - avg_token=False, - output_cls_token=False, + out_type='featmap', final_norm=False), neck=None, head=dict( diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py index 6c7227b9..8f64ae20 100644 --- a/mmpretrain/models/backbones/beit.py +++ b/mmpretrain/models/backbones/beit.py @@ -4,13 +4,12 @@ from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch import torch.nn as nn -from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import FFN, PatchEmbed from mmengine.model import BaseModule, ModuleList from mmpretrain.registry import MODELS -from ..utils import (BEiTAttention, resize_pos_embed, +from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed, resize_relative_position_bias_table, to_2tuple) from .vision_transformer import TransformerEncoderLayer, VisionTransformer @@ -203,12 +202,12 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer): rel_pos_bias: torch.Tensor) -> torch.Tensor: if self.gamma_1 is None: x = x + self.drop_path( - self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) - x = x + self.drop_path(self.ffn(self.norm2(x))) + self.attn(self.ln1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.ffn(self.ln2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn( - self.norm1(x), rel_pos_bias=rel_pos_bias)) - x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x))) + self.ln1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.ffn(self.ln2(x))) return x @@ -251,15 +250,21 @@ class BEiTViT(VisionTransformer): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"avg_featmap"``. with_cls_token (bool): Whether concatenating class token into image tokens as transformer input. Defaults to True. - avg_token (bool): Whether or not to use the mean patch token for - classification. If True, the model will only take the average - of all patch tokens. Defaults to False. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Defaults to -1. - output_cls_token (bool): Whether output the cls_token. 
If set True, - ``with_cls_token`` must be True. Defaults to True. use_abs_pos_emb (bool): Use position embedding like vanilla ViT. Defaults to False. use_rel_pos_bias (bool): Use relative position embedding in each @@ -289,10 +294,9 @@ class BEiTViT(VisionTransformer): bias='qv_bias', norm_cfg=dict(type='LN', eps=1e-6), final_norm=False, + out_type='avg_featmap', with_cls_token=True, - avg_token=True, frozen_stages=-1, - output_cls_token=False, use_abs_pos_emb=False, use_rel_pos_bias=True, use_shared_rel_pos_bias=False, @@ -334,17 +338,25 @@ class BEiTViT(VisionTransformer): self.patch_resolution = self.patch_embed.init_out_size num_patches = self.patch_resolution[0] * self.patch_resolution[1] - # Set cls token - if output_cls_token: - assert with_cls_token is True, f'with_cls_token must be True if' \ - f'set output_cls_token to True, but got {with_cls_token}' - self.with_cls_token = with_cls_token - self.output_cls_token = output_cls_token - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type - self.interpolate_mode = interpolate_mode + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') # Set position embedding + self.interpolate_mode = interpolate_mode if use_abs_pos_emb: self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + self.num_extra_tokens, @@ -405,15 +417,10 @@ class BEiTViT(VisionTransformer): self.frozen_stages = frozen_stages self.final_norm = final_norm if final_norm: - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=1) - self.add_module(self.norm1_name, norm1) + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) - self.avg_token = avg_token - if avg_token: - self.norm2_name, norm2 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=2) - self.add_module(self.norm2_name, norm2) + if out_type == 'avg_featmap': + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) # freeze stages only when self.frozen_stages > 0 if self.frozen_stages > 0: @@ -423,9 +430,10 @@ class BEiTViT(VisionTransformer): B = x.shape[0] x, patch_resolution = self.patch_embed(x) - # stole cls_tokens impl from Phil Wang, thanks - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) if self.pos_embed is not None: x = x + resize_pos_embed( @@ -439,42 +447,32 @@ class BEiTViT(VisionTransformer): rel_pos_bias = self.rel_pos_bias() \ if self.rel_pos_bias is not None else None - if not self.with_cls_token: - # Remove class token for transformer encoder input - x = x[:, 1:] - outs = [] for i, layer in enumerate(self.layers): x = layer(x, rel_pos_bias) if i == len(self.layers) - 1 and self.final_norm: - x = self.norm1(x) + x = self.ln1(x) if i in self.out_indices: - B, _, C = x.shape - if self.with_cls_token: - patch_token = x[:, 1:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] - else: - patch_token = x.reshape(B, *patch_resolution, C) - patch_token = 
patch_token.permute(0, 3, 1, 2) - cls_token = None - - if self.avg_token: - patch_token = patch_token.permute(0, 2, 3, 1) - patch_token = patch_token.reshape( - B, patch_resolution[0] * patch_resolution[1], - C).mean(dim=1) - patch_token = self.norm2(patch_token) - if self.output_cls_token: - out = [patch_token, cls_token] - else: - out = patch_token - outs.append(out) + outs.append(self._format_output(x, patch_resolution)) return tuple(outs) + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return self.ln2(patch_token.mean(dim=1)) + def _prepare_relative_position_bias_table(self, state_dict, prefix, *args, **kwargs): from mmengine.logging import MMLogger diff --git a/mmpretrain/models/backbones/deit.py b/mmpretrain/models/backbones/deit.py index 20cd32e5..9ae34082 100644 --- a/mmpretrain/models/backbones/deit.py +++ b/mmpretrain/models/backbones/deit.py @@ -43,10 +43,18 @@ class DistilledVisionTransformer(VisionTransformer): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. - with_cls_token (bool): Whether concatenating class token into image - tokens as transformer input. Defaults to True. - output_cls_token (bool): Whether output the cls_token. If set True, - ``with_cls_token`` must be True. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: A tuple with the class token and the + distillation token. The shapes of both tensor are (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. @@ -55,11 +63,15 @@ class DistilledVisionTransformer(VisionTransformer): init_cfg (dict, optional): Initialization config dict. Defaults to None. 
""" - num_extra_tokens = 2 # cls_token, dist_token + num_extra_tokens = 2 # class token and distillation token def __init__(self, arch='deit-base', *args, **kwargs): super(DistilledVisionTransformer, self).__init__( - arch=arch, *args, **kwargs) + arch=arch, + with_cls_token=True, + *args, + **kwargs, + ) self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) def forward(self, x): @@ -78,37 +90,24 @@ class DistilledVisionTransformer(VisionTransformer): num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) - if not self.with_cls_token: - # Remove class token for transformer encoder input - x = x[:, 2:] - outs = [] for i, layer in enumerate(self.layers): x = layer(x) if i == len(self.layers) - 1 and self.final_norm: - x = self.norm1(x) + x = self.ln1(x) if i in self.out_indices: - B, _, C = x.shape - if self.with_cls_token: - patch_token = x[:, 2:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] - dist_token = x[:, 1] - else: - patch_token = x.reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = None - dist_token = None - if self.output_cls_token: - out = [patch_token, cls_token, dist_token] - else: - out = patch_token - outs.append(out) + outs.append(self._format_output(x, patch_resolution)) return tuple(outs) + def _format_output(self, x, hw): + if self.out_type == 'cls_token': + return x[:, 0], x[:, 1] + + return super()._format_output(x, hw) + def init_weights(self): super(DistilledVisionTransformer, self).init_weights() diff --git a/mmpretrain/models/backbones/deit3.py b/mmpretrain/models/backbones/deit3.py index 68ee0cb0..9be36279 100644 --- a/mmpretrain/models/backbones/deit3.py +++ b/mmpretrain/models/backbones/deit3.py @@ -3,7 +3,7 @@ from typing import Sequence import numpy as np import torch -from mmcv.cnn import Linear, build_activation_layer, build_norm_layer +from mmcv.cnn import Linear, build_activation_layer from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import PatchEmbed from mmengine.model import BaseModule, ModuleList, Sequential @@ -11,7 +11,8 @@ from mmengine.utils import deprecated_api_warning from torch import nn from mmpretrain.registry import MODELS -from ..utils import LayerScale, MultiheadAttention, resize_pos_embed, to_2tuple +from ..utils import (LayerScale, MultiheadAttention, build_norm_layer, + resize_pos_embed, to_2tuple) from .vision_transformer import VisionTransformer @@ -149,9 +150,7 @@ class DeiT3TransformerEncoderLayer(BaseModule): self.embed_dims = embed_dims - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=1) - self.add_module(self.norm1_name, norm1) + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) self.attn = MultiheadAttention( embed_dims=embed_dims, @@ -162,9 +161,7 @@ class DeiT3TransformerEncoderLayer(BaseModule): qkv_bias=qkv_bias, use_layer_scale=use_layer_scale) - self.norm2_name, norm2 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=2) - self.add_module(self.norm2_name, norm2) + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) self.ffn = DeiT3FFN( embed_dims=embed_dims, @@ -175,14 +172,6 @@ class DeiT3TransformerEncoderLayer(BaseModule): act_cfg=act_cfg, use_layer_scale=use_layer_scale) - @property - def norm1(self): - return getattr(self, self.norm1_name) - - @property - def norm2(self): - return getattr(self, self.norm2_name) - def init_weights(self): super(DeiT3TransformerEncoderLayer, self).init_weights() for m in 
self.ffn.modules(): @@ -191,8 +180,8 @@ class DeiT3TransformerEncoderLayer(BaseModule): nn.init.normal_(m.bias, std=1e-6) def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = self.ffn(self.norm2(x), identity=x) + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) return x @@ -237,10 +226,19 @@ class DeiT3(VisionTransformer): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. with_cls_token (bool): Whether concatenating class token into image tokens as transformer input. Defaults to True. - output_cls_token (bool): Whether output the cls_token. If set True, - ``with_cls_token`` must be True. Defaults to True. use_layer_scale (bool): Whether to use layer_scale in DeiT3. Defaults to True. interpolate_mode (str): Select the interpolate mode for position @@ -288,9 +286,7 @@ class DeiT3(VisionTransformer): 'feedforward_channels': 5120 }), } - # not using num_extra_tokens in deit3 because adding cls tokens after - # adding pos_embed - num_extra_tokens = 0 + num_extra_tokens = 1 # class token def __init__(self, arch='base', @@ -303,8 +299,8 @@ class DeiT3(VisionTransformer): qkv_bias=True, norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + out_type='cls_token', with_cls_token=True, - output_cls_token=True, use_layer_scale=True, interpolate_mode='bicubic', patch_cfg=dict(), @@ -343,13 +339,21 @@ class DeiT3(VisionTransformer): self.patch_resolution = self.patch_embed.init_out_size num_patches = self.patch_resolution[0] * self.patch_resolution[1] + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + # Set cls token - if output_cls_token: - assert with_cls_token is True, f'with_cls_token must be True if' \ - f'set output_cls_token to True, but got {with_cls_token}' - self.with_cls_token = with_cls_token - self.output_cls_token = output_cls_token - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') # Set position embedding self.interpolate_mode = interpolate_mode @@ -393,9 +397,7 @@ class DeiT3(VisionTransformer): self.final_norm = final_norm if final_norm: - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=1) - self.add_module(self.norm1_name, norm1) + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) def forward(self, x): B = x.shape[0] @@ -406,38 +408,47 @@ class DeiT3(VisionTransformer): self.patch_resolution, patch_resolution, mode=self.interpolate_mode, - num_extra_tokens=self.num_extra_tokens) + num_extra_tokens=0) x = self.drop_after_pos(x) - # stole cls_tokens impl from Phil Wang, thanks - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) - - if not self.with_cls_token: - #
Remove class token for transformer encoder input - x = x[:, 1:] + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) outs = [] for i, layer in enumerate(self.layers): x = layer(x) if i == len(self.layers) - 1 and self.final_norm: - x = self.norm1(x) + x = self.ln1(x) if i in self.out_indices: - B, _, C = x.shape - if self.with_cls_token: - patch_token = x[:, 1:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] - else: - patch_token = x.reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = None - if self.output_cls_token: - out = [patch_token, cls_token] - else: - out = patch_token - outs.append(out) + outs.append(self._format_output(x, patch_resolution)) return tuple(outs) + + def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs): + name = prefix + 'pos_embed' + if name not in state_dict.keys(): + return + + ckpt_pos_embed_shape = state_dict[name].shape + if self.pos_embed.shape != ckpt_pos_embed_shape: + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.info( + f'Resize the pos_embed shape from {ckpt_pos_embed_shape} ' + f'to {self.pos_embed.shape}.') + + ckpt_pos_embed_shape = to_2tuple( + int(np.sqrt(ckpt_pos_embed_shape[1]))) + pos_embed_shape = self.patch_embed.init_out_size + + state_dict[name] = resize_pos_embed( + state_dict[name], + ckpt_pos_embed_shape, + pos_embed_shape, + self.interpolate_mode, + num_extra_tokens=0, # The cls token adding is after pos_embed + ) diff --git a/mmpretrain/models/backbones/mixmim.py b/mmpretrain/models/backbones/mixmim.py index 8b520bd4..2c67aa0c 100644 --- a/mmpretrain/models/backbones/mixmim.py +++ b/mmpretrain/models/backbones/mixmim.py @@ -196,7 +196,7 @@ class MixMIMBlock(TransformerEncoderLayer): B, L, C = x.shape shortcut = x - x = self.norm1(x) + x = self.ln1(x) x = x.view(B, H, W, C) # partition windows diff --git a/mmpretrain/models/backbones/revvit.py b/mmpretrain/models/backbones/revvit.py index ec3b2ce6..f2e6c28c 100644 --- a/mmpretrain/models/backbones/revvit.py +++ b/mmpretrain/models/backbones/revvit.py @@ -1,10 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import sys -from typing import Sequence import numpy as np import torch -from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.drop import build_dropout from mmcv.cnn.bricks.transformer import FFN, PatchEmbed from mmengine.model import BaseModule, ModuleList @@ -14,7 +12,8 @@ from torch.autograd import Function as Function from mmpretrain.models.backbones.base_backbone import BaseBackbone from mmpretrain.registry import MODELS -from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) class RevBackProp(Function): @@ -152,9 +151,7 @@ class RevTransformerEncoderLayer(BaseModule): self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate) self.embed_dims = embed_dims - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=1) - self.add_module(self.norm1_name, norm1) + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) self.attn = MultiheadAttention( embed_dims=embed_dims, @@ -163,9 +160,7 @@ class RevTransformerEncoderLayer(BaseModule): proj_drop=drop_rate, qkv_bias=qkv_bias) - self.norm2_name, norm2 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=2) - self.add_module(self.norm2_name, norm2) + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) self.ffn = FFN( embed_dims=embed_dims, @@ -178,14 +173,6 @@ class RevTransformerEncoderLayer(BaseModule): self.layer_id = layer_id self.seeds = {} - @property - def norm1(self): - return getattr(self, self.norm1_name) - - @property - def norm2(self): - return getattr(self, self.norm2_name) - def init_weights(self): super(RevTransformerEncoderLayer, self).init_weights() for m in self.ffn.modules(): @@ -215,13 +202,13 @@ class RevTransformerEncoderLayer(BaseModule): Implementation of Reversible TransformerEncoderLayer ` - x = x + self.attn(self.norm1(x)) - x = self.ffn(self.norm2(x), identity=x) + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) ` """ self.seed_cuda('attn') # attention output - f_x2 = self.attn(self.norm1(x2)) + f_x2 = self.attn(self.ln1(x2)) # apply droppath on attention output self.seed_cuda('droppath') f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2) @@ -233,7 +220,7 @@ class RevTransformerEncoderLayer(BaseModule): # ffn output self.seed_cuda('ffn') - g_y1 = self.ffn(self.norm2(y1)) + g_y1 = self.ffn(self.ln2(y1)) # apply droppath on ffn output torch.manual_seed(self.seeds['droppath']) g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1) @@ -259,7 +246,7 @@ class RevTransformerEncoderLayer(BaseModule): y1.requires_grad = True torch.manual_seed(self.seeds['ffn']) - g_y1 = self.ffn(self.norm2(y1)) + g_y1 = self.ffn(self.ln2(y1)) torch.manual_seed(self.seeds['droppath']) g_y1 = build_dropout(self.drop_path_cfg)(g_y1) @@ -280,7 +267,7 @@ class RevTransformerEncoderLayer(BaseModule): x2.requires_grad = True torch.manual_seed(self.seeds['attn']) - f_x2 = self.attn(self.norm1(x2)) + f_x2 = self.attn(self.ln1(x2)) torch.manual_seed(self.seeds['droppath']) f_x2 = build_dropout(self.drop_path_cfg)(f_x2) @@ -337,7 +324,8 @@ class TwoStreamFusion(nn.Module): class RevVisionTransformer(BaseBackbone): """Reversible Vision Transformer. - A PyTorch implementation of : `Reversible Vision Transformers `_ # noqa: E501 + A PyTorch implementation of : `Reversible Vision Transformers + `_ # noqa: E501 Args: arch (str | dict): Vision Transformer architecture. 
If use string, @@ -357,8 +345,6 @@ class RevVisionTransformer(BaseBackbone): patch_size (int | tuple): The patch size in patch embedding. Defaults to 16. in_channels (int): The num of input channels. Defaults to 3. - out_indices (Sequence | int): Output from which stages. - Defaults to -1, means the last stage. drop_rate (float): Probability of an element to be zeroed. Defaults to 0. drop_path_rate (float): stochastic depth rate. Defaults to 0. @@ -368,15 +354,21 @@ class RevVisionTransformer(BaseBackbone): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"avg_featmap"``. with_cls_token (bool): Whether concatenating class token into image - tokens as transformer input. Defaults to True. - avg_token (bool): Whether or not to use the mean patch token for - classification. If True, the model will only take the average - of all patch tokens. Defaults to False. + tokens as transformer input. Defaults to False. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Defaults to -1. - output_cls_token (bool): Whether output the cls_token. If set True, - ``with_cls_token`` must be True. Defaults to True. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. @@ -443,24 +435,22 @@ class RevVisionTransformer(BaseBackbone): 'feedforward_channels': 768 * 4 }), } - # Some structures have multiple extra tokens, like DeiT. 
- num_extra_tokens = 1 # cls_token + num_extra_tokens = 0 # The official RevViT doesn't have class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} def __init__(self, arch='base', img_size=224, patch_size=16, in_channels=3, - out_indices=-1, drop_rate=0., drop_path_rate=0., qkv_bias=True, norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + out_type='avg_featmap', with_cls_token=False, - avg_token=True, frozen_stages=-1, - output_cls_token=False, interpolate_mode='bicubic', patch_cfg=dict(), layer_cfgs=dict(), @@ -501,15 +491,22 @@ class RevVisionTransformer(BaseBackbone): self.patch_resolution = self.patch_embed.init_out_size num_patches = self.patch_resolution[0] * self.patch_resolution[1] - # Set cls token - if output_cls_token: - assert with_cls_token is True, f'with_cls_token must be True if' \ - f'set output_cls_token to True, but got {with_cls_token}' - self.with_cls_token = with_cls_token - assert with_cls_token is False, 'with_cls_token=True is not supported' + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type - self.output_cls_token = output_cls_token - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + # Set cls token + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') # Set position embedding self.interpolate_mode = interpolate_mode @@ -520,20 +517,6 @@ class RevVisionTransformer(BaseBackbone): self.drop_after_pos = nn.Dropout(p=drop_rate) - if isinstance(out_indices, int): - out_indices = [out_indices] - assert isinstance(out_indices, Sequence), \ - f'"out_indices" must by a sequence or int, ' \ - f'get {type(out_indices)} instead.' 
- for i, index in enumerate(out_indices): - if index < 0: - out_indices[i] = self.num_layers + index - assert 0 <= out_indices[i] <= self.num_layers, \ - f'Invalid out_indices {index}' - self.out_indices = out_indices - assert out_indices == [-1] or out_indices == [self.num_layers - 1], \ - f'only support output last layer current, but got {out_indices}' - # stochastic depth decay rule dpr = np.linspace(0, drop_path_rate, self.num_layers) @@ -560,20 +543,12 @@ class RevVisionTransformer(BaseBackbone): self.frozen_stages = frozen_stages self.final_norm = final_norm if final_norm: - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, self.embed_dims * 2, postfix=1) - self.add_module(self.norm1_name, norm1) - - self.avg_token = avg_token + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2) # freeze stages only when self.frozen_stages > 0 if self.frozen_stages > 0: self._freeze_stages() - @property - def norm1(self): - return getattr(self, self.norm1_name) - def init_weights(self): super(RevVisionTransformer, self).init_weights() if not (isinstance(self.init_cfg, dict) @@ -618,7 +593,8 @@ class RevVisionTransformer(BaseBackbone): for param in self.patch_embed.parameters(): param.requires_grad = False # freeze cls_token - # self.cls_token.requires_grad = False + if self.cls_token is not None: + self.cls_token.requires_grad = False # freeze layers for i in range(1, self.frozen_stages + 1): m = self.layers[i - 1] @@ -627,17 +603,17 @@ class RevVisionTransformer(BaseBackbone): param.requires_grad = False # freeze the last layer norm if self.frozen_stages == len(self.layers) and self.final_norm: - self.norm1.eval() - for param in self.norm1.parameters(): + self.ln1.eval() + for param in self.ln1.parameters(): param.requires_grad = False def forward(self, x): B = x.shape[0] x, patch_resolution = self.patch_embed(x) - # stole cls_tokens impl from Phil Wang, thanks - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) + if self.cls_token is not None: + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) x = x + resize_pos_embed( self.pos_embed, @@ -647,10 +623,6 @@ class RevVisionTransformer(BaseBackbone): num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) - if not self.with_cls_token: - # Remove class token for transformer encoder input - x = x[:, 1:] - x = torch.cat([x, x], dim=-1) # forward with different conditions @@ -664,33 +636,10 @@ class RevVisionTransformer(BaseBackbone): x = executing_fn(x, self.layers, []) if self.final_norm: - x = self.norm1(x) + x = self.ln1(x) x = self.fusion_layer(x) - if self.with_cls_token: - # RevViT does not allow cls_token - raise NotImplementedError - else: - # (B, H, W, C) - _, __, C = x.shape - patch_token = x.reshape(B, *patch_resolution, C) - # (B, C, H, W) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = None - - if self.avg_token: - # (B, H, W, C) - patch_token = patch_token.permute(0, 2, 3, 1) - # (B, L, C) -> (B, C) - patch_token = patch_token.reshape( - B, patch_resolution[0] * patch_resolution[1], C).mean(dim=1) - - if self.output_cls_token: - out = [patch_token, cls_token] - else: - out = patch_token - - return tuple([out]) + return (self._format_output(x, patch_resolution), ) @staticmethod def _forward_vanilla_bp(hidden_state, layers, buffer=[]): @@ -706,3 +655,17 @@ class RevVisionTransformer(BaseBackbone): attn_out, ffn_out = layer(attn_out, ffn_out) return torch.cat([attn_out, ffn_out], dim=-1) + + def _format_output(self, x, hw): + if 
self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git a/mmpretrain/models/backbones/t2t_vit.py b/mmpretrain/models/backbones/t2t_vit.py index 4195a7de..288ef0dc 100644 --- a/mmpretrain/models/backbones/t2t_vit.py +++ b/mmpretrain/models/backbones/t2t_vit.py @@ -5,13 +5,13 @@ from typing import Sequence import numpy as np import torch import torch.nn as nn -from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.transformer import FFN from mmengine.model import BaseModule, ModuleList from mmengine.model.weight_init import trunc_normal_ from mmpretrain.registry import MODELS -from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) from .base_backbone import BaseBackbone @@ -70,9 +70,7 @@ class T2TTransformerLayer(BaseModule): self.v_shortcut = True if input_dims is not None else False input_dims = input_dims or embed_dims - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, input_dims, postfix=1) - self.add_module(self.norm1_name, norm1) + self.ln1 = build_norm_layer(norm_cfg, input_dims) self.attn = MultiheadAttention( input_dims=input_dims, @@ -85,9 +83,7 @@ class T2TTransformerLayer(BaseModule): qk_scale=qk_scale or (input_dims // num_heads)**-0.5, v_shortcut=self.v_shortcut) - self.norm2_name, norm2 = build_norm_layer( - norm_cfg, embed_dims, postfix=2) - self.add_module(self.norm2_name, norm2) + self.ln2 = build_norm_layer(norm_cfg, embed_dims) self.ffn = FFN( embed_dims=embed_dims, @@ -97,20 +93,12 @@ class T2TTransformerLayer(BaseModule): dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), act_cfg=act_cfg) - @property - def norm1(self): - return getattr(self, self.norm1_name) - - @property - def norm2(self): - return getattr(self, self.norm2_name) - def forward(self, x): if self.v_shortcut: - x = self.attn(self.norm1(x)) + x = self.attn(self.ln1(x)) else: - x = x + self.attn(self.norm1(x)) - x = self.ffn(self.norm2(x), identity=x) + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) return x @@ -265,10 +253,19 @@ class T2T_ViT(BaseBackbone): ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. with_cls_token (bool): Whether concatenating class token into image tokens as transformer input. Defaults to True. - output_cls_token (bool): Whether output the cls_token. If set True, - ``with_cls_token`` must be True. Defaults to True. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". t2t_cfg (dict): Extra config of Tokens-to-Token module. @@ -278,7 +275,7 @@ class T2T_ViT(BaseBackbone): init_cfg (dict, optional): The Config for initialization. Defaults to None. 
""" - num_extra_tokens = 1 # cls_token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} def __init__(self, img_size=224, @@ -290,13 +287,13 @@ class T2T_ViT(BaseBackbone): drop_path_rate=0., norm_cfg=dict(type='LN'), final_norm=True, + out_type='cls_token', with_cls_token=True, - output_cls_token=True, interpolate_mode='bicubic', t2t_cfg=dict(), layer_cfgs=dict(), init_cfg=None): - super(T2T_ViT, self).__init__(init_cfg) + super().__init__(init_cfg) # Token-to-Token Module self.tokens_to_token = T2TModule( @@ -307,13 +304,22 @@ class T2T_ViT(BaseBackbone): self.patch_resolution = self.tokens_to_token.init_out_size num_patches = self.patch_resolution[0] * self.patch_resolution[1] + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + # Set cls token - if output_cls_token: - assert with_cls_token is True, f'with_cls_token must be True if' \ - f'set output_cls_token to True, but got {with_cls_token}' - self.with_cls_token = with_cls_token - self.output_cls_token = output_cls_token - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims)) + self.num_extra_tokens = 1 + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') # Set position embedding self.interpolate_mode = interpolate_mode @@ -360,7 +366,7 @@ class T2T_ViT(BaseBackbone): self.final_norm = final_norm if final_norm: - self.norm = build_norm_layer(norm_cfg, embed_dims)[1] + self.norm = build_norm_layer(norm_cfg, embed_dims) else: self.norm = nn.Identity() @@ -401,9 +407,10 @@ class T2T_ViT(BaseBackbone): B = x.shape[0] x, patch_resolution = self.tokens_to_token(x) - # stole cls_tokens impl from Phil Wang, thanks - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) x = x + resize_pos_embed( self.pos_embed, @@ -413,10 +420,6 @@ class T2T_ViT(BaseBackbone): num_extra_tokens=self.num_extra_tokens) x = self.drop_after_pos(x) - if not self.with_cls_token: - # Remove class token for transformer encoder input - x = x[:, 1:] - outs = [] for i, layer in enumerate(self.encoder): x = layer(x) @@ -425,19 +428,20 @@ class T2T_ViT(BaseBackbone): x = self.norm(x) if i in self.out_indices: - B, _, C = x.shape - if self.with_cls_token: - patch_token = x[:, 1:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] - else: - patch_token = x.reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = None - if self.output_cls_token: - out = [patch_token, cls_token] - else: - out = patch_token - outs.append(out) + outs.append(self._format_output(x, patch_resolution)) return tuple(outs) + + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) diff --git 
a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py index 50231c99..a46d32a5 100644 --- a/mmpretrain/models/backbones/vision_transformer.py +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -4,13 +4,13 @@ from typing import Sequence import numpy as np import torch import torch.nn as nn -from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.transformer import FFN, PatchEmbed from mmengine.model import BaseModule, ModuleList from mmengine.model.weight_init import trunc_normal_ from mmpretrain.registry import MODELS -from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple +from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed, + to_2tuple) from .base_backbone import BaseBackbone @@ -53,9 +53,7 @@ class TransformerEncoderLayer(BaseModule): self.embed_dims = embed_dims - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=1) - self.add_module(self.norm1_name, norm1) + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) self.attn = MultiheadAttention( embed_dims=embed_dims, @@ -65,9 +63,7 @@ class TransformerEncoderLayer(BaseModule): dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate), qkv_bias=qkv_bias) - self.norm2_name, norm2 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=2) - self.add_module(self.norm2_name, norm2) + self.ln2 = build_norm_layer(norm_cfg, self.embed_dims) self.ffn = FFN( embed_dims=embed_dims, @@ -79,11 +75,11 @@ class TransformerEncoderLayer(BaseModule): @property def norm1(self): - return getattr(self, self.norm1_name) + return self.ln1 @property def norm2(self): - return getattr(self, self.norm2_name) + return self.ln2 def init_weights(self): super(TransformerEncoderLayer, self).init_weights() @@ -93,8 +89,8 @@ class TransformerEncoderLayer(BaseModule): nn.init.normal_(m.bias, std=1e-6) def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = self.ffn(self.norm2(x), identity=x) + x = x + self.attn(self.ln1(x)) + x = self.ffn(self.ln2(x), identity=x) return x @@ -134,15 +130,21 @@ class VisionTransformer(BaseBackbone): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + Defaults to ``"cls_token"``. with_cls_token (bool): Whether concatenating class token into image tokens as transformer input. Defaults to True. - avg_token (bool): Whether or not to use the mean patch token for - classification. If True, the model will only take the average - of all patch tokens. Defaults to False. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Defaults to -1. - output_cls_token (bool): Whether output the cls_token. If set True, - ``with_cls_token`` must be True. Defaults to True. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. 
@@ -215,8 +217,8 @@ class VisionTransformer(BaseBackbone): 'feedforward_channels': 768 * 4 }), } - # Some structures have multiple extra tokens, like DeiT. - num_extra_tokens = 1 # cls_token + num_extra_tokens = 1 # class token + OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'} def __init__(self, arch='base', @@ -229,10 +231,9 @@ class VisionTransformer(BaseBackbone): qkv_bias=True, norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + out_type='cls_token', with_cls_token=True, - avg_token=False, frozen_stages=-1, - output_cls_token=True, interpolate_mode='bicubic', patch_cfg=dict(), layer_cfgs=dict(), @@ -272,13 +273,21 @@ class VisionTransformer(BaseBackbone): self.patch_resolution = self.patch_embed.init_out_size num_patches = self.patch_resolution[0] * self.patch_resolution[1] + # Set out type + if out_type not in self.OUT_TYPES: + raise ValueError(f'Unsupported `out_type` {out_type}, please ' + f'choose from {self.OUT_TYPES}') + self.out_type = out_type + # Set cls token - if output_cls_token: - assert with_cls_token is True, f'with_cls_token must be True if' \ - f'set output_cls_token to True, but got {with_cls_token}' - self.with_cls_token = with_cls_token - self.output_cls_token = output_cls_token - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + if with_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + elif out_type != 'cls_token': + self.cls_token = None + self.num_extra_tokens = 0 + else: + raise ValueError( + 'with_cls_token must be True when `out_type="cls_token"`.') # Set position embedding self.interpolate_mode = interpolate_mode @@ -322,34 +331,25 @@ class VisionTransformer(BaseBackbone): self.frozen_stages = frozen_stages if pre_norm: - _, norm_layer = build_norm_layer( - norm_cfg, self.embed_dims, postfix=1) + self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims) else: - norm_layer = nn.Identity() - self.add_module('pre_norm', norm_layer) + self.pre_norm = nn.Identity() self.final_norm = final_norm if final_norm: - self.norm1_name, norm1 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=1) - self.add_module(self.norm1_name, norm1) + self.ln1 = build_norm_layer(norm_cfg, self.embed_dims) - self.avg_token = avg_token - if avg_token: - self.norm2_name, norm2 = build_norm_layer( - norm_cfg, self.embed_dims, postfix=2) - self.add_module(self.norm2_name, norm2) # freeze stages only when self.frozen_stages > 0 if self.frozen_stages > 0: self._freeze_stages() @property def norm1(self): - return getattr(self, self.norm1_name) + return self.ln1 @property def norm2(self): - return getattr(self, self.norm2_name) + return self.ln2 def init_weights(self): super(VisionTransformer, self).init_weights() @@ -407,17 +407,19 @@ class VisionTransformer(BaseBackbone): param.requires_grad = False # freeze the last layer norm if self.frozen_stages == len(self.layers) and self.final_norm: - self.norm1.eval() - for param in self.norm1.parameters(): + self.ln1.eval() + for param in self.ln1.parameters(): param.requires_grad = False def forward(self, x): B = x.shape[0] x, patch_resolution = self.patch_embed(x) - # stole cls_tokens impl from Phil Wang, thanks - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) + if self.cls_token is not None: + # stole cls_tokens impl from Phil Wang, thanks + cls_token = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_token, x), dim=1) + x = x + resize_pos_embed( self.pos_embed, self.patch_resolution, @@ -427,41 +429,33 @@ class 
VisionTransformer(BaseBackbone): x = self.drop_after_pos(x) x = self.pre_norm(x) - if not self.with_cls_token: - # Remove class token for transformer encoder input - x = x[:, 1:] outs = [] for i, layer in enumerate(self.layers): x = layer(x) if i == len(self.layers) - 1 and self.final_norm: - x = self.norm1(x) + x = self.ln1(x) if i in self.out_indices: - B, _, C = x.shape - if self.with_cls_token: - patch_token = x[:, 1:].reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = x[:, 0] - else: - patch_token = x.reshape(B, *patch_resolution, C) - patch_token = patch_token.permute(0, 3, 1, 2) - cls_token = None - if self.avg_token: - patch_token = patch_token.permute(0, 2, 3, 1) - patch_token = patch_token.reshape( - B, patch_resolution[0] * patch_resolution[1], - C).mean(dim=1) - patch_token = self.norm2(patch_token) - if self.output_cls_token: - out = [patch_token, cls_token] - else: - out = patch_token - outs.append(out) + outs.append(self._format_output(x, patch_resolution)) return tuple(outs) + def _format_output(self, x, hw): + if self.out_type == 'raw': + return x + if self.out_type == 'cls_token': + return x[:, 0] + + patch_token = x[:, self.num_extra_tokens:] + if self.out_type == 'featmap': + B = x.size(0) + # (B, N, C) -> (B, H, W, C) -> (B, C, H, W) + return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2) + if self.out_type == 'avg_featmap': + return patch_token.mean(dim=1) + def get_layer_depth(self, param_name: str, prefix: str = ''): """Get the layer-wise depth of a parameter. diff --git a/mmpretrain/models/heads/deit_head.py b/mmpretrain/models/heads/deit_head.py index c2ce9d4c..a96f6e15 100644 --- a/mmpretrain/models/heads/deit_head.py +++ b/mmpretrain/models/heads/deit_head.py @@ -47,7 +47,12 @@ class DeiTClsHead(VisionTransformerClsHead): the feature of a backbone stage. In ``DeiTClsHead``, we obtain the feature of the last stage and forward in hidden layer if exists. """ - _, cls_token, dist_token = feats[-1] + feat = feats[-1] # Obtain feature of the last scale. + # For backward-compatibility with the previous ViT output + if len(feat) == 3: + _, cls_token, dist_token = feat + else: + cls_token, dist_token = feat if self.hidden_dim is None: return cls_token, dist_token else: diff --git a/mmpretrain/models/heads/vision_transformer_head.py b/mmpretrain/models/heads/vision_transformer_head.py index a7194d91..83e8fca1 100644 --- a/mmpretrain/models/heads/vision_transformer_head.py +++ b/mmpretrain/models/heads/vision_transformer_head.py @@ -80,7 +80,9 @@ class VisionTransformerClsHead(ClsHead): obtain the feature of the last stage and forward in hidden layer if exists. """ - _, cls_token = feats[-1] + feat = feats[-1] # Obtain feature of the last scale. 
+ # For backward-compatibility with the previous ViT output + cls_token = feat[-1] if isinstance(feat, list) else feat if self.hidden_dim is None: return cls_token else: diff --git a/mmpretrain/models/necks/mae_neck.py b/mmpretrain/models/necks/mae_neck.py index 01c7b3da..773692dc 100644 --- a/mmpretrain/models/necks/mae_neck.py +++ b/mmpretrain/models/necks/mae_neck.py @@ -183,7 +183,6 @@ class ClsBatchNormNeck(BaseModule): self, inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]: """The forward function.""" - # Only apply batch norm to cls token, which is the second tensor in - # each item of the tuple - inputs = [[input_[0], self.bn(input_[1])] for input_ in inputs] + # Only apply batch norm to cls_token + inputs = [self.bn(input_) for input_ in inputs] return tuple(inputs) diff --git a/mmpretrain/models/necks/nonlinear_neck.py b/mmpretrain/models/necks/nonlinear_neck.py index 311a9afd..ef684d39 100644 --- a/mmpretrain/models/necks/nonlinear_neck.py +++ b/mmpretrain/models/necks/nonlinear_neck.py @@ -32,8 +32,6 @@ class NonLinearNeck(BaseModule): Defaults to False. with_avg_pool (bool): Whether to apply the global average pooling after backbone. Defaults to True. - vit_backbone (bool): The key to indicate whether the upstream backbone - is ViT. Defaults to False. norm_cfg (dict): Dictionary to construct and config norm layer. Defaults to dict(type='SyncBN'). init_cfg (dict or list[dict], optional): Initialization config dict. @@ -50,7 +48,6 @@ class NonLinearNeck(BaseModule): with_last_bn_affine: bool = True, with_last_bias: bool = False, with_avg_pool: bool = True, - vit_backbone: bool = False, norm_cfg: dict = dict(type='SyncBN'), init_cfg: Optional[Union[dict, List[dict]]] = [ dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) @@ -58,7 +55,6 @@ class NonLinearNeck(BaseModule): ) -> None: super(NonLinearNeck, self).__init__(init_cfg) self.with_avg_pool = with_avg_pool - self.vit_backbone = vit_backbone if with_avg_pool: self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.relu = nn.ReLU(inplace=True) @@ -104,8 +100,6 @@ class NonLinearNeck(BaseModule): """ assert len(x) == 1 x = x[0] - if self.vit_backbone: - x = x[-1] if self.with_avg_pool: x = self.avgpool(x) x = x.view(x.size(0), -1) diff --git a/mmpretrain/models/selfsup/beit.py b/mmpretrain/models/selfsup/beit.py index dfa3dba1..13b39bdf 100644 --- a/mmpretrain/models/selfsup/beit.py +++ b/mmpretrain/models/selfsup/beit.py @@ -140,15 +140,21 @@ class BEiTPretrainViT(BEiTViT): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. with_cls_token (bool): Whether concatenating class token into image tokens as transformer input. Defaults to True. - avg_token (bool): Whether or not to use the mean patch token for - classification. If True, the model will only take the average - of all patch tokens. Defaults to False. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. Defaults to -1. 
- output_cls_token (bool): Whether output the cls_token. If set True, - ``with_cls_token`` must be True. Defaults to True. use_abs_pos_emb (bool): Whether or not use absolute position embedding. Defaults to False. use_rel_pos_bias (bool): Whether or not use relative position bias. @@ -176,9 +182,8 @@ class BEiTPretrainViT(BEiTViT): drop_path_rate: float = 0, norm_cfg: dict = dict(type='LN', eps=1e-6), final_norm: bool = True, - avg_token: bool = False, + out_type: str = 'avg_featmap', frozen_stages: int = -1, - output_cls_token: bool = True, use_abs_pos_emb: bool = False, use_rel_pos_bias: bool = False, use_shared_rel_pos_bias: bool = True, @@ -197,9 +202,9 @@ class BEiTPretrainViT(BEiTViT): drop_path_rate=drop_path_rate, norm_cfg=norm_cfg, final_norm=final_norm, - avg_token=avg_token, + out_type=out_type, + with_cls_token=True, frozen_stages=frozen_stages, - output_cls_token=output_cls_token, use_abs_pos_emb=use_abs_pos_emb, use_shared_rel_pos_bias=use_shared_rel_pos_bias, use_rel_pos_bias=use_rel_pos_bias, diff --git a/mmpretrain/models/selfsup/cae.py b/mmpretrain/models/selfsup/cae.py index b58e6312..2c7cfeae 100644 --- a/mmpretrain/models/selfsup/cae.py +++ b/mmpretrain/models/selfsup/cae.py @@ -217,8 +217,17 @@ class CAEPretrainViT(BEiTViT): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. - output_cls_token (bool): Whether output the cls_token. If set True, - `with_cls_token` must be True. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". layer_scale_init_value (float, optional): The init value of gamma in @@ -242,10 +251,8 @@ class CAEPretrainViT(BEiTViT): bias: bool = 'qv_bias', norm_cfg: dict = dict(type='LN', eps=1e-6), final_norm: bool = True, - with_cls_token: bool = True, - avg_token: bool = False, + out_type: str = 'avg_featmap', frozen_stages: int = -1, - output_cls_token: bool = True, use_abs_pos_emb: bool = True, use_rel_pos_bias: bool = False, use_shared_rel_pos_bias: bool = False, @@ -270,10 +277,9 @@ class CAEPretrainViT(BEiTViT): bias=bias, norm_cfg=norm_cfg, final_norm=final_norm, - with_cls_token=with_cls_token, - avg_token=avg_token, + out_type=out_type, + with_cls_token=True, frozen_stages=frozen_stages, - output_cls_token=output_cls_token, use_abs_pos_emb=use_abs_pos_emb, use_rel_pos_bias=use_rel_pos_bias, use_shared_rel_pos_bias=use_shared_rel_pos_bias, diff --git a/mmpretrain/models/selfsup/mae.py b/mmpretrain/models/selfsup/mae.py index 7f354308..178acff4 100644 --- a/mmpretrain/models/selfsup/mae.py +++ b/mmpretrain/models/selfsup/mae.py @@ -33,8 +33,17 @@ class MAEViT(VisionTransformer): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. - output_cls_token (bool): Whether output the cls_token. If set True, - `with_cls_token` must be True. Defaults to True. + out_type (str): The type of output features. 
Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. @@ -55,7 +64,7 @@ class MAEViT(VisionTransformer): drop_path_rate: float = 0, norm_cfg: dict = dict(type='LN', eps=1e-6), final_norm: bool = True, - output_cls_token: bool = True, + out_type: str = 'avg_featmap', interpolate_mode: str = 'bicubic', patch_cfg: dict = dict(), layer_cfgs: dict = dict(), @@ -70,7 +79,8 @@ class MAEViT(VisionTransformer): drop_path_rate=drop_path_rate, norm_cfg=norm_cfg, final_norm=final_norm, - output_cls_token=output_cls_token, + out_type=out_type, + with_cls_token=True, interpolate_mode=interpolate_mode, patch_cfg=patch_cfg, layer_cfgs=layer_cfgs, diff --git a/mmpretrain/models/selfsup/maskfeat.py b/mmpretrain/models/selfsup/maskfeat.py index c63143a3..c765051c 100644 --- a/mmpretrain/models/selfsup/maskfeat.py +++ b/mmpretrain/models/selfsup/maskfeat.py @@ -178,8 +178,17 @@ class MaskFeatViT(VisionTransformer): Defaults to ``dict(type='LN')``. final_norm (bool): Whether to add a additional layer to normalize final feature map. Defaults to True. - output_cls_token (bool): Whether output the cls_token. If set True, - `with_cls_token` must be True. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. interpolate_mode (str): Select the interpolate mode for position embeding vector resize. Defaults to "bicubic". patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. @@ -198,7 +207,7 @@ class MaskFeatViT(VisionTransformer): drop_path_rate: float = 0, norm_cfg: dict = dict(type='LN', eps=1e-6), final_norm: bool = True, - output_cls_token: bool = True, + out_type: str = 'avg_featmap', interpolate_mode: str = 'bicubic', patch_cfg: dict = dict(), layer_cfgs: dict = dict(), @@ -212,7 +221,8 @@ class MaskFeatViT(VisionTransformer): drop_path_rate=drop_path_rate, norm_cfg=norm_cfg, final_norm=final_norm, - output_cls_token=output_cls_token, + out_type=out_type, + with_cls_token=True, interpolate_mode=interpolate_mode, patch_cfg=patch_cfg, layer_cfgs=layer_cfgs, diff --git a/mmpretrain/models/utils/position_encoding.py b/mmpretrain/models/utils/position_encoding.py index 6b6fc7f4..a200c066 100644 --- a/mmpretrain/models/utils/position_encoding.py +++ b/mmpretrain/models/utils/position_encoding.py @@ -8,6 +8,14 @@ import torch.nn as nn from mmengine.model import BaseModule from mmengine.utils import digit_version +# After pytorch v1.10.0, use torch.meshgrid without indexing +# will raise extra warning. 
For more details, +# refers to https://github.com/pytorch/pytorch/issues/50276 +if digit_version(torch.__version__) >= digit_version('1.10.0'): + torch_meshgrid = partial(torch.meshgrid, indexing='ij') +else: + torch_meshgrid = torch.meshgrid + class ConditionalPositionEncoding(BaseModule): """The Conditional Position Encoding (CPE) module. @@ -137,7 +145,7 @@ def build_2d_sincos_position_embedding( h, w = patches_resolution grid_w = torch.arange(w, dtype=torch.float32) grid_h = torch.arange(h, dtype=torch.float32) - grid_w, grid_h = torch.meshgrid(grid_w, grid_h) + grid_w, grid_h = torch_meshgrid(grid_w, grid_h) assert embed_dims % 4 == 0, \ 'Embed dimension must be divisible by 4.' pos_dim = embed_dims // 4 diff --git a/tests/test_models/test_backbones/test_beit.py b/tests/test_models/test_backbones/test_beit.py index 5ed7287f..eed2be5d 100644 --- a/tests/test_models/test_backbones/test_beit.py +++ b/tests/test_models/test_backbones/test_beit.py @@ -82,18 +82,16 @@ class TestBEiT(TestCase): def test_forward(self): imgs = torch.randn(1, 3, 224, 224) - # test with output_cls_token cfg = deepcopy(self.cfg) - cfg['output_cls_token'] = True + cfg['out_type'] = 'cls_token' model = BEiTViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 768)) + cls_token = outs[-1] self.assertEqual(cls_token.shape, (1, 768)) - # test without output_cls_token + # test without output cls_token cfg = deepcopy(self.cfg) model = BEiTViT(**cfg) outs = model(imgs) @@ -104,7 +102,7 @@ class TestBEiT(TestCase): # test without average cfg = deepcopy(self.cfg) - cfg['avg_token'] = False + cfg['out_type'] = 'featmap' model = BEiTViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) diff --git a/tests/test_models/test_backbones/test_cspnet.py b/tests/test_models/test_backbones/test_cspnet.py index 656e9d00..9063e2fe 100644 --- a/tests/test_models/test_backbones/test_cspnet.py +++ b/tests/test_models/test_backbones/test_cspnet.py @@ -72,7 +72,6 @@ class TestCSPDarkNet(TestCase): def test_forward(self): imgs = torch.randn(3, 3, 224, 224) - # test without output_cls_token cfg = deepcopy(self.cfg) model = self.class_name(**cfg) outs = model(imgs) diff --git a/tests/test_models/test_backbones/test_deit.py b/tests/test_models/test_backbones/test_deit.py index 21914651..b2d096df 100644 --- a/tests/test_models/test_backbones/test_deit.py +++ b/tests/test_models/test_backbones/test_deit.py @@ -62,37 +62,19 @@ class TestDeiT(TestCase): def test_forward(self): imgs = torch.randn(1, 3, 224, 224) - # test with_cls_token=False - cfg = deepcopy(self.cfg) - cfg['with_cls_token'] = False - cfg['output_cls_token'] = True - with self.assertRaisesRegex(AssertionError, 'but got False'): - DistilledVisionTransformer(**cfg) - - cfg = deepcopy(self.cfg) - cfg['with_cls_token'] = False - cfg['output_cls_token'] = False - model = DistilledVisionTransformer(**cfg) - outs = model(imgs) - self.assertIsInstance(outs, tuple) - self.assertEqual(len(outs), 1) - patch_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 192, 14, 14)) - - # test with output_cls_token + # test with output cls_token cfg = deepcopy(self.cfg) model = DistilledVisionTransformer(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token, dist_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 192, 14, 14)) + cls_token, dist_token = outs[-1] self.assertEqual(cls_token.shape, (1, 
192)) self.assertEqual(dist_token.shape, (1, 192)) - # test without output_cls_token + # test without output cls_token cfg = deepcopy(self.cfg) - cfg['output_cls_token'] = False + cfg['out_type'] = 'featmap' model = DistilledVisionTransformer(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) @@ -108,8 +90,7 @@ class TestDeiT(TestCase): self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 3) for out in outs: - patch_token, cls_token, dist_token = out - self.assertEqual(patch_token.shape, (1, 192, 14, 14)) + cls_token, dist_token = out self.assertEqual(cls_token.shape, (1, 192)) self.assertEqual(dist_token.shape, (1, 192)) @@ -118,14 +99,13 @@ class TestDeiT(TestCase): imgs2 = torch.randn(1, 3, 256, 256) imgs3 = torch.randn(1, 3, 256, 309) cfg = deepcopy(self.cfg) + cfg['out_type'] = 'featmap' model = DistilledVisionTransformer(**cfg) for imgs in [imgs1, imgs2, imgs3]: outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token, dist_token = outs[-1] + featmap = outs[-1] expect_feat_shape = (math.ceil(imgs.shape[2] / 16), math.ceil(imgs.shape[3] / 16)) - self.assertEqual(patch_token.shape, (1, 192, *expect_feat_shape)) - self.assertEqual(cls_token.shape, (1, 192)) - self.assertEqual(dist_token.shape, (1, 192)) + self.assertEqual(featmap.shape, (1, 192, *expect_feat_shape)) diff --git a/tests/test_models/test_backbones/test_deit3.py b/tests/test_models/test_backbones/test_deit3.py index 7e7aa485..7acb5072 100644 --- a/tests/test_models/test_backbones/test_deit3.py +++ b/tests/test_models/test_backbones/test_deit3.py @@ -116,13 +116,13 @@ class TestDeiT3(TestCase): # test with_cls_token=False cfg = deepcopy(self.cfg) cfg['with_cls_token'] = False - cfg['output_cls_token'] = True - with self.assertRaisesRegex(AssertionError, 'but got False'): + cfg['out_type'] = 'cls_token' + with self.assertRaisesRegex(ValueError, 'must be True'): DeiT3(**cfg) cfg = deepcopy(self.cfg) cfg['with_cls_token'] = False - cfg['output_cls_token'] = False + cfg['out_type'] = 'featmap' model = DeiT3(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) @@ -130,26 +130,15 @@ class TestDeiT3(TestCase): patch_token = outs[-1] self.assertEqual(patch_token.shape, (1, 768, 14, 14)) - # test with output_cls_token + # test with output cls_token cfg = deepcopy(self.cfg) model = DeiT3(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 768, 14, 14)) + cls_token = outs[-1] self.assertEqual(cls_token.shape, (1, 768)) - # test without output_cls_token - cfg = deepcopy(self.cfg) - cfg['output_cls_token'] = False - model = DeiT3(**cfg) - outs = model(imgs) - self.assertIsInstance(outs, tuple) - self.assertEqual(len(outs), 1) - patch_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 768, 14, 14)) - # Test forward with multi out indices cfg = deepcopy(self.cfg) cfg['out_indices'] = [-3, -2, -1] @@ -158,8 +147,7 @@ class TestDeiT3(TestCase): self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 3) for out in outs: - patch_token, cls_token = out - self.assertEqual(patch_token.shape, (1, 768, 14, 14)) + cls_token = out self.assertEqual(cls_token.shape, (1, 768)) # Test forward with dynamic input size @@ -167,13 +155,13 @@ class TestDeiT3(TestCase): imgs2 = torch.randn(1, 3, 256, 256) imgs3 = torch.randn(1, 3, 256, 309) cfg = deepcopy(self.cfg) + cfg['out_type'] = 'featmap' model = DeiT3(**cfg) for imgs in [imgs1, imgs2, imgs3]: 
outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token = outs[-1] + featmap = outs[-1] expect_feat_shape = (math.ceil(imgs.shape[2] / 16), math.ceil(imgs.shape[3] / 16)) - self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape)) - self.assertEqual(cls_token.shape, (1, 768)) + self.assertEqual(featmap.shape, (1, 768, *expect_feat_shape)) diff --git a/tests/test_models/test_backbones/test_revvit.py b/tests/test_models/test_backbones/test_revvit.py index 1f234949..f18ca782 100644 --- a/tests/test_models/test_backbones/test_revvit.py +++ b/tests/test_models/test_backbones/test_revvit.py @@ -49,16 +49,6 @@ class TestRevVisionTransformer(TestCase): self.assertEqual(layer.attn.num_heads, 16) self.assertEqual(layer.ffn.feedforward_channels, 1024) - # Test out_indices - # TODO: to be implemented, current only support last layer - cfg = deepcopy(self.cfg) - cfg['out_indices'] = {1: 1} - with self.assertRaisesRegex(AssertionError, "get "): - RevVisionTransformer(**cfg) - cfg['out_indices'] = [13] - with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'): - RevVisionTransformer(**cfg) - # Test model structure cfg = deepcopy(self.cfg) model = RevVisionTransformer(**cfg) @@ -108,8 +98,8 @@ class TestRevVisionTransformer(TestCase): cfg['img_size'] = 384 model = RevVisionTransformer(**cfg) load_checkpoint(model, checkpoint, strict=True) - resized_pos_embed = timm_resize_pos_embed(pretrain_pos_embed, - model.pos_embed) + resized_pos_embed = timm_resize_pos_embed( + pretrain_pos_embed, model.pos_embed, num_tokens=0) self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed)) os.remove(checkpoint) @@ -119,7 +109,7 @@ class TestRevVisionTransformer(TestCase): cfg = deepcopy(self.cfg) cfg['with_cls_token'] = False - cfg['output_cls_token'] = False + cfg['out_type'] = 'avg_featmap' model = RevVisionTransformer(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) @@ -137,5 +127,5 @@ class TestRevVisionTransformer(TestCase): outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - avg_token = outs[-1] - self.assertEqual(avg_token.shape, (1, 768 * 2)) + avg_featmap = outs[-1] + self.assertEqual(avg_featmap.shape, (1, 768 * 2)) diff --git a/tests/test_models/test_backbones/test_t2t_vit.py b/tests/test_models/test_backbones/test_t2t_vit.py index f0466ba1..76bfe9ce 100644 --- a/tests/test_models/test_backbones/test_t2t_vit.py +++ b/tests/test_models/test_backbones/test_t2t_vit.py @@ -94,13 +94,13 @@ class TestT2TViT(TestCase): # test with_cls_token=False cfg = deepcopy(self.cfg) cfg['with_cls_token'] = False - cfg['output_cls_token'] = True - with self.assertRaisesRegex(AssertionError, 'but got False'): + cfg['out_type'] = 'cls_token' + with self.assertRaisesRegex(ValueError, 'must be True'): T2T_ViT(**cfg) cfg = deepcopy(self.cfg) cfg['with_cls_token'] = False - cfg['output_cls_token'] = False + cfg['out_type'] = 'featmap' model = T2T_ViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) @@ -108,26 +108,15 @@ class TestT2TViT(TestCase): patch_token = outs[-1] self.assertEqual(patch_token.shape, (1, 384, 14, 14)) - # test with output_cls_token + # test with output cls_token cfg = deepcopy(self.cfg) model = T2T_ViT(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 384, 14, 14)) + cls_token = outs[-1] self.assertEqual(cls_token.shape, (1, 384)) - # test without 
output_cls_token - cfg = deepcopy(self.cfg) - cfg['output_cls_token'] = False - model = T2T_ViT(**cfg) - outs = model(imgs) - self.assertIsInstance(outs, tuple) - self.assertEqual(len(outs), 1) - patch_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 384, 14, 14)) - # Test forward with multi out indices cfg = deepcopy(self.cfg) cfg['out_indices'] = [-3, -2, -1] @@ -136,22 +125,20 @@ class TestT2TViT(TestCase): self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 3) for out in outs: - patch_token, cls_token = out - self.assertEqual(patch_token.shape, (1, 384, 14, 14)) - self.assertEqual(cls_token.shape, (1, 384)) + self.assertEqual(out.shape, (1, 384)) # Test forward with dynamic input size imgs1 = torch.randn(1, 3, 224, 224) imgs2 = torch.randn(1, 3, 256, 256) imgs3 = torch.randn(1, 3, 256, 309) cfg = deepcopy(self.cfg) + cfg['out_type'] = 'featmap' model = T2T_ViT(**cfg) for imgs in [imgs1, imgs2, imgs3]: outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token = outs[-1] + patch_token = outs[-1] expect_feat_shape = (math.ceil(imgs.shape[2] / 16), math.ceil(imgs.shape[3] / 16)) self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape)) - self.assertEqual(cls_token.shape, (1, 384)) diff --git a/tests/test_models/test_backbones/test_vision_transformer.py b/tests/test_models/test_backbones/test_vision_transformer.py index 31a049fd..d6638ae3 100644 --- a/tests/test_models/test_backbones/test_vision_transformer.py +++ b/tests/test_models/test_backbones/test_vision_transformer.py @@ -126,13 +126,13 @@ class TestVisionTransformer(TestCase): # test with_cls_token=False cfg = deepcopy(self.cfg) cfg['with_cls_token'] = False - cfg['output_cls_token'] = True - with self.assertRaisesRegex(AssertionError, 'but got False'): + cfg['out_type'] = 'cls_token' + with self.assertRaisesRegex(ValueError, 'must be True'): VisionTransformer(**cfg) cfg = deepcopy(self.cfg) cfg['with_cls_token'] = False - cfg['output_cls_token'] = False + cfg['out_type'] = 'featmap' model = VisionTransformer(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) @@ -140,26 +140,15 @@ class TestVisionTransformer(TestCase): patch_token = outs[-1] self.assertEqual(patch_token.shape, (1, 768, 14, 14)) - # test with output_cls_token + # test with output cls_token cfg = deepcopy(self.cfg) model = VisionTransformer(**cfg) outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 768, 14, 14)) + cls_token = outs[-1] self.assertEqual(cls_token.shape, (1, 768)) - # test without output_cls_token - cfg = deepcopy(self.cfg) - cfg['output_cls_token'] = False - model = VisionTransformer(**cfg) - outs = model(imgs) - self.assertIsInstance(outs, tuple) - self.assertEqual(len(outs), 1) - patch_token = outs[-1] - self.assertEqual(patch_token.shape, (1, 768, 14, 14)) - # Test forward with multi out indices cfg = deepcopy(self.cfg) cfg['out_indices'] = [-3, -2, -1] @@ -168,22 +157,20 @@ class TestVisionTransformer(TestCase): self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 3) for out in outs: - patch_token, cls_token = out - self.assertEqual(patch_token.shape, (1, 768, 14, 14)) - self.assertEqual(cls_token.shape, (1, 768)) + self.assertEqual(out.shape, (1, 768)) # Test forward with dynamic input size imgs1 = torch.randn(1, 3, 224, 224) imgs2 = torch.randn(1, 3, 256, 256) imgs3 = torch.randn(1, 3, 256, 309) cfg = deepcopy(self.cfg) + cfg['out_type'] = 
'featmap' model = VisionTransformer(**cfg) for imgs in [imgs1, imgs2, imgs3]: outs = model(imgs) self.assertIsInstance(outs, tuple) self.assertEqual(len(outs), 1) - patch_token, cls_token = outs[-1] + patch_token = outs[-1] expect_feat_shape = (math.ceil(imgs.shape[2] / 16), math.ceil(imgs.shape[3] / 16)) self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape)) - self.assertEqual(cls_token.shape, (1, 768)) diff --git a/tests/test_models/test_selfsup/test_beit.py b/tests/test_models/test_selfsup/test_beit.py index cfa1a87c..4066a78e 100644 --- a/tests/test_models/test_selfsup/test_beit.py +++ b/tests/test_models/test_selfsup/test_beit.py @@ -35,9 +35,7 @@ class TestBEiT(TestCase): # test without mask fake_outputs = beit_backbone(fake_inputs, None) - assert len(fake_outputs[0]) == 2 - assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14]) - assert fake_outputs[0][1].shape == torch.Size([2, 768]) + assert fake_outputs[0].shape == torch.Size([2, 768]) @pytest.mark.skipif( platform.system() == 'Windows', reason='Windows mem limit') @@ -111,10 +109,9 @@ class TestBEiT(TestCase): drop_path_rate=0., norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + out_type='featmap', with_cls_token=True, - avg_token=False, frozen_stages=-1, - output_cls_token=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, diff --git a/tests/test_models/test_selfsup/test_cae.py b/tests/test_models/test_selfsup/test_cae.py index c8954c75..fb5f5e59 100644 --- a/tests/test_models/test_selfsup/test_cae.py +++ b/tests/test_models/test_selfsup/test_cae.py @@ -25,9 +25,7 @@ def test_cae_vit(): # test without mask fake_outputs = cae_backbone(fake_inputs, None) - assert len(fake_outputs[0]) == 2 - assert fake_outputs[0][0].shape == torch.Size([1, 192, 14, 14]) - assert fake_outputs[0][1].shape == torch.Size([1, 192]) + assert fake_outputs[0].shape == torch.Size([1, 192]) @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') diff --git a/tests/test_models/test_selfsup/test_mae.py b/tests/test_models/test_selfsup/test_mae.py index fb73e165..8201d5f3 100644 --- a/tests/test_models/test_selfsup/test_mae.py +++ b/tests/test_models/test_selfsup/test_mae.py @@ -21,9 +21,7 @@ def test_mae_vit(): # test without mask fake_outputs = mae_backbone(fake_inputs, None) - assert len(fake_outputs[0]) == 2 - assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14]) - assert fake_outputs[0][1].shape == torch.Size([2, 768]) + assert fake_outputs[0].shape == torch.Size([2, 768]) @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') diff --git a/tests/test_models/test_selfsup/test_maskfeat.py b/tests/test_models/test_selfsup/test_maskfeat.py index 351924b4..5feaa2a3 100644 --- a/tests/test_models/test_selfsup/test_maskfeat.py +++ b/tests/test_models/test_selfsup/test_maskfeat.py @@ -22,9 +22,7 @@ def test_maskfeat_vit(): # test without mask fake_outputs = maskfeat_backbone(fake_inputs, None) - assert len(fake_outputs[0]) == 2 - assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14]) - assert fake_outputs[0][1].shape == torch.Size([2, 768]) + assert fake_outputs[0].shape == torch.Size([2, 768]) @pytest.mark.skipif( diff --git a/tests/test_models/test_selfsup/test_milan.py b/tests/test_models/test_selfsup/test_milan.py index b8c91ecf..12ad9aee 100644 --- a/tests/test_models/test_selfsup/test_milan.py +++ b/tests/test_models/test_selfsup/test_milan.py @@ -24,9 +24,7 @@ def test_milan_vit(): # test without mask fake_outputs = 
milan_backbone(fake_inputs, None) - assert len(fake_outputs[0]) == 2 - assert fake_outputs[0][0].shape == torch.Size([2, 768, 14, 14]) - assert fake_outputs[0][1].shape == torch.Size([2, 768]) + assert fake_outputs[0].shape == torch.Size([2, 768]) @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') diff --git a/tests/test_models/test_selfsup/test_mocov3.py b/tests/test_models/test_selfsup/test_mocov3.py index a3f1291f..b9d89a90 100644 --- a/tests/test_models/test_selfsup/test_mocov3.py +++ b/tests/test_models/test_selfsup/test_mocov3.py @@ -29,7 +29,6 @@ class TestMoCoV3(TestCase): with_last_bn_affine=False, with_last_bias=False, with_avg_pool=False, - vit_backbone=True, norm_cfg=dict(type='BN1d')) head = dict( type='MoCoV3Head', @@ -89,5 +88,4 @@ class TestMoCoV3(TestCase): # test extract fake_feats = alg(fake_inputs['inputs'][0], mode='tensor') - self.assertEqual(fake_feats[0][0].size(), torch.Size([2, 384, 14, 14])) - self.assertEqual(fake_feats[0][1].size(), torch.Size([2, 384])) + self.assertEqual(fake_feats[0].size(), torch.Size([2, 384])) diff --git a/tests/test_models/test_selfsup/test_target_generators.py b/tests/test_models/test_selfsup/test_target_generators.py index 08f53fbe..f53530b1 100644 --- a/tests/test_models/test_selfsup/test_target_generators.py +++ b/tests/test_models/test_selfsup/test_target_generators.py @@ -50,10 +50,9 @@ class TestVQKD(TestCase): drop_path_rate=0., norm_cfg=dict(type='LN', eps=1e-6), final_norm=True, + out_type='featmap', with_cls_token=True, - avg_token=False, frozen_stages=-1, - output_cls_token=False, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
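
For reference, a minimal usage sketch of the refactored `out_type` argument, assuming this patch is applied on top of mmpretrain and the usual import path for the backbone; the expected shapes mirror the updated unit tests above (`arch='base'`, a 224x224 input, 16x16 patches).

    import torch
    from mmpretrain.models.backbones import VisionTransformer

    imgs = torch.randn(1, 3, 224, 224)
    for out_type, expect in [
            ('cls_token', (1, 768)),        # class token only
            ('featmap', (1, 768, 14, 14)),  # patch tokens as a feature map
            ('avg_featmap', (1, 768)),      # mean over the patch tokens
            ('raw', (1, 197, 768)),         # full token sequence, cls included
    ]:
        vit = VisionTransformer(
            arch='base', img_size=224, patch_size=16, out_type=out_type)
        outs = vit(imgs)                    # a tuple, one entry per out_indices
        assert outs[-1].shape == expect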
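
The classification heads above now accept both the legacy list output (`[patch_token, cls_token]`, plus a `dist_token` for DeiT) and the new plain tensor. A small sketch of that unpacking; the helper name `unpack_last_feat` is purely illustrative and not part of the patch.

    import torch

    def unpack_last_feat(feat):
        # VisionTransformerClsHead: take the last element of a legacy list,
        # otherwise the tensor is already the cls_token.
        return feat[-1] if isinstance(feat, list) else feat

    legacy = [torch.randn(1, 768, 14, 14), torch.randn(1, 768)]  # old ViT output
    new = torch.randn(1, 768)                                    # out_type='cls_token'
    assert unpack_last_feat(legacy).shape == (1, 768)
    assert unpack_last_feat(new).shape == (1, 768)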
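
The version-guarded `torch_meshgrid` shim added to position_encoding.py can also be exercised on its own; note that `functools.partial` is assumed to be imported in that module, since the hunk shown above does not include its import line.

    from functools import partial

    import torch
    from mmengine.utils import digit_version

    if digit_version(torch.__version__) >= digit_version('1.10.0'):
        torch_meshgrid = partial(torch.meshgrid, indexing='ij')
    else:
        torch_meshgrid = torch.meshgrid

    # indexing='ij' reproduces the pre-1.10 default grid layout and avoids the
    # UserWarning that plain torch.meshgrid emits on newer PyTorch releases.
    grid_w, grid_h = torch_meshgrid(torch.arange(4.), torch.arange(3.))
    assert grid_w.shape == grid_h.shape == (4, 3)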