From 7a7b048f23922607dd88bfab6b123e4784875ef1 Mon Sep 17 00:00:00 2001 From: RenQin <45731309+soonera@users.noreply.github.com> Date: Tue, 6 Dec 2022 16:40:05 +0800 Subject: [PATCH] [Feature]: Add BEiT Support (#425) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Feature]: Add BEiT Support * [Fix]: fix bugs after update * [Fix]: fix bugs in backbone * [Refactor]: refactor config * [Feature]: Support BEiTv2 * [Fix]: Fix UT * [Fix]: rename some configs * [Fix]: fix beitv2neck * [Refactor]: refactor beitv2 * [Fix]: fix lint * refactor configs * refactor beitv2 * update configs * add dalle target generator * refactor for beitv1 * refactor rel_pos_bias of beit * update configs * update configs * update v1 configs * update v2 configs * refactoe layer decay * update unittest * fix lint * fix ut * add docstrings * rename * fix lint * add beit model and log links * fix lint * update according to review * update * update * update LearningRateDecayOptimWrapperConstructor related configs * update init and backbone * update neck and vqkd * refactor neck * fix lint * add some comments * fix typo Co-authored-by: 任琴 Co-authored-by: fangyixiao18 --- ...swin-base_ft-8xb256-coslr-100e_in1k-192.py | 8 +- ...-base-p16_ft-8xb128-coslr-100e-rpe_in1k.py | 9 +- .../vit-base-p16_ft-8xb128-coslr-100e_in1k.py | 4 +- .../selfsup/_base_/datasets/imagenet_beit.py | 52 ++++ .../_base_/datasets/imagenet_beitv2.py | 52 ++++ .../_base_/models/beit_vit-base-p16.py | 28 +++ .../_base_/models/beitv2_vit-base-p16.py | 59 +++++ configs/selfsup/beit/README.md | 60 +++++ ...vit-base-p16_8xb256-amp-coslr-300e_in1k.py | 56 +++++ .../vit-base-p16_ft-8xb128-coslr-100e_in1k.py | 134 ++++++++++ configs/selfsup/beit/metafile.yml | 35 +++ configs/selfsup/beitv2/README.md | 60 +++++ ...vit-base-p16_8xb256-amp-coslr-300e_in1k.py | 55 +++++ .../vit-base-p16_ft-8xb128-coslr-100e_in1k.py | 129 ++++++++++ .../vit-base-p16_ft-8xb128-coslr-30e_in1k.py | 126 ++++++++++ configs/selfsup/beitv2/metafile.yml | 35 +++ docs/en/model_zoo.md | 11 + docs/zh_cn/model_zoo.md | 11 + .../layer_decay_optim_wrapper_constructor.py | 45 ++-- mmselfsup/models/algorithms/__init__.py | 7 +- mmselfsup/models/algorithms/beit.py | 67 +++++ mmselfsup/models/backbones/__init__.py | 3 +- mmselfsup/models/backbones/beit_vit.py | 188 ++++++++++++++ mmselfsup/models/heads/__init__.py | 9 +- mmselfsup/models/heads/beitv1_head.py | 55 +++++ mmselfsup/models/heads/beitv2_head.py | 56 +++++ mmselfsup/models/losses/__init__.py | 3 +- mmselfsup/models/losses/beit_loss.py | 39 +++ mmselfsup/models/necks/__init__.py | 3 +- mmselfsup/models/necks/beitv2_neck.py | 153 ++++++++++++ .../models/target_generators/__init__.py | 6 +- mmselfsup/models/target_generators/dall_e.py | 180 ++++++++++++++ mmselfsup/models/target_generators/vqkd.py | 104 ++++++++ mmselfsup/models/utils/__init__.py | 27 +- mmselfsup/models/utils/data_preprocessor.py | 113 ++++++++- mmselfsup/models/utils/vector_quantizer.py | 232 ++++++++++++++++++ model-index.yml | 1 + ...t_layer_decay_optim_wrapper_constructor.py | 25 +- .../test_algorithms/test_beitv1.py | 65 +++++ .../test_algorithms/test_beitv2.py | 99 ++++++++ .../test_backbones/test_beit_vit.py | 28 +++ .../test_dalle.py | 2 +- .../test_target_generators/test_vqkd.py | 39 +++ .../test_utils/test_data_preprocessor.py | 52 +++- 44 files changed, 2475 insertions(+), 50 deletions(-) create mode 100644 configs/selfsup/_base_/datasets/imagenet_beit.py create mode 100644 configs/selfsup/_base_/datasets/imagenet_beitv2.py create mode 100644 configs/selfsup/_base_/models/beit_vit-base-p16.py create mode 100644 configs/selfsup/_base_/models/beitv2_vit-base-p16.py create mode 100644 configs/selfsup/beit/README.md create mode 100644 configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py create mode 100644 configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py create mode 100644 configs/selfsup/beit/metafile.yml create mode 100644 configs/selfsup/beitv2/README.md create mode 100644 configs/selfsup/beitv2/beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k.py create mode 100644 configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py create mode 100644 configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-30e_in1k.py create mode 100644 configs/selfsup/beitv2/metafile.yml create mode 100644 mmselfsup/models/algorithms/beit.py create mode 100644 mmselfsup/models/backbones/beit_vit.py create mode 100644 mmselfsup/models/heads/beitv1_head.py create mode 100644 mmselfsup/models/heads/beitv2_head.py create mode 100644 mmselfsup/models/losses/beit_loss.py create mode 100644 mmselfsup/models/necks/beitv2_neck.py create mode 100644 mmselfsup/models/target_generators/dall_e.py create mode 100644 mmselfsup/models/target_generators/vqkd.py create mode 100644 mmselfsup/models/utils/vector_quantizer.py create mode 100644 tests/test_models/test_algorithms/test_beitv1.py create mode 100644 tests/test_models/test_algorithms/test_beitv2.py create mode 100644 tests/test_models/test_backbones/test_beit_vit.py rename tests/test_models/{test_utils => test_target_generators}/test_dalle.py (86%) create mode 100644 tests/test_models/test_target_generators/test_vqkd.py diff --git a/configs/benchmarks/classification/imagenet/swin-base_ft-8xb256-coslr-100e_in1k-192.py b/configs/benchmarks/classification/imagenet/swin-base_ft-8xb256-coslr-100e_in1k-192.py index c23d680e..e4c1c8a3 100644 --- a/configs/benchmarks/classification/imagenet/swin-base_ft-8xb256-coslr-100e_in1k-192.py +++ b/configs/benchmarks/classification/imagenet/swin-base_ft-8xb256-coslr-100e_in1k-192.py @@ -19,14 +19,14 @@ optim_wrapper = dict( optimizer=dict( type='AdamW', lr=5e-3, model_type='swin', layer_decay_rate=0.9), clip_grad=dict(max_norm=5.0), + constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor', paramwise_cfg=dict( - norm_decay_mult=0.0, - bias_decay_mult=0.0, custom_keys={ + '.norm': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), '.absolute_pos_embed': dict(decay_mult=0.0), '.relative_position_bias_table': dict(decay_mult=0.0) - }), - constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor') + })) # learning rate scheduler param_scheduler = [ diff --git a/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e-rpe_in1k.py b/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e-rpe_in1k.py index 90966787..31ec3107 100644 --- a/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e-rpe_in1k.py +++ b/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e-rpe_in1k.py @@ -97,7 +97,14 @@ optim_wrapper = dict( weight_decay=0.05, model_type='vit', # layer-wise lr decay type layer_decay_rate=0.65), # layer-wise lr decay factor - constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor') + constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + custom_keys={ + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) # learning rate scheduler param_scheduler = [ diff --git a/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e_in1k.py b/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e_in1k.py index 8f5b13d4..c9d247bb 100644 --- a/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e_in1k.py +++ b/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e_in1k.py @@ -88,9 +88,9 @@ optim_wrapper = dict( layer_decay_rate=0.65), # layer-wise lr decay factor constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor', paramwise_cfg=dict( - norm_decay_mult=0.0, - bias_decay_mult=0.0, custom_keys={ + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), '.cls_token': dict(decay_mult=0.0), '.pos_embed': dict(decay_mult=0.0) })) diff --git a/configs/selfsup/_base_/datasets/imagenet_beit.py b/configs/selfsup/_base_/datasets/imagenet_beit.py new file mode 100644 index 00000000..fd539a31 --- /dev/null +++ b/configs/selfsup/_base_/datasets/imagenet_beit.py @@ -0,0 +1,52 @@ +# dataset settings +dataset_type = 'mmcls.ImageNet' +data_root = 'data/imagenet/' +file_client_args = dict(backend='disk') +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='ColorJitter', + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandomResizedCropAndInterpolationWithTwoPic', + size=224, + second_size=112, + interpolation='bicubic', + second_interpolation='lanczos', + scale=(0.08, 1.0)), + dict( + type='BEiTMaskGenerator', + input_size=(14, 14), + num_masking_patches=75, + max_num_patches=None, + min_num_patches=16), + dict( + type='PackSelfSupInputs', + algorithm_keys=['mask'], + meta_keys=['img_path']) +] + +data_preprocessor = dict( + type='TwoNormDataPreprocessor', + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(-20.4, -20.4, -20.4), + second_std=(204., 204., 204.), + bgr_to_rgb=True) + +train_dataloader = dict( + batch_size=256, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) diff --git a/configs/selfsup/_base_/datasets/imagenet_beitv2.py b/configs/selfsup/_base_/datasets/imagenet_beitv2.py new file mode 100644 index 00000000..7267817c --- /dev/null +++ b/configs/selfsup/_base_/datasets/imagenet_beitv2.py @@ -0,0 +1,52 @@ +# dataset settings +dataset_type = 'mmcls.ImageNet' +data_root = 'data/imagenet/' +file_client_args = dict(backend='disk') +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='ColorJitter', + brightness=0.4, + contrast=0.4, + saturation=0.4, + hue=0.), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandomResizedCropAndInterpolationWithTwoPic', + size=224, + second_size=224, + interpolation='bicubic', + second_interpolation='bicubic', + scale=(0.2, 1.0)), + dict( + type='BEiTMaskGenerator', + input_size=(14, 14), + num_masking_patches=75, + max_num_patches=75, + min_num_patches=16), + dict( + type='PackSelfSupInputs', + algorithm_keys=['mask'], + meta_keys=['img_path']) +] + +data_preprocessor = dict( + type='TwoNormDataPreprocessor', + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5, 127.5), + second_std=(127.5, 127.5, 127.5), + bgr_to_rgb=True) + +train_dataloader = dict( + batch_size=256, + num_workers=8, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='meta/train.txt', + data_prefix=dict(img_path='train/'), + pipeline=train_pipeline)) diff --git a/configs/selfsup/_base_/models/beit_vit-base-p16.py b/configs/selfsup/_base_/models/beit_vit-base-p16.py new file mode 100644 index 00000000..cbc35be8 --- /dev/null +++ b/configs/selfsup/_base_/models/beit_vit-base-p16.py @@ -0,0 +1,28 @@ +# model settings +model = dict( + type='BEiT', + backbone=dict( + type='BEiTViT', + arch='base', + patch_size=16, + drop_path_rate=0.1, + final_norm=True, + layer_scale_init_value=0.1, + init_cfg=[ + dict(type='TruncNormal', std=0.02, layer='Linear'), + dict(type='TruncNormal', std=0.02, layer='Conv2d'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=None, + head=dict( + type='BEiTV1Head', + embed_dims=768, + num_embed=8192, + loss=dict(type='BEiTLoss')), + target_generator=dict( + type='DALL-E', + init_cfg=dict( + type='Pretrained', + checkpoint= # noqa: E251 + 'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/dalle_encoder.pth', # noqa: E501 + ))) diff --git a/configs/selfsup/_base_/models/beitv2_vit-base-p16.py b/configs/selfsup/_base_/models/beitv2_vit-base-p16.py new file mode 100644 index 00000000..96478fe9 --- /dev/null +++ b/configs/selfsup/_base_/models/beitv2_vit-base-p16.py @@ -0,0 +1,59 @@ +# model settings +vqkd_encoder = dict( + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + with_cls_token=True, + avg_token=False, + frozen_stages=-1, + output_cls_token=False, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0., + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None) + +layer_scale_init_value = 0.1 +drop_path_rate = 0. # 0. for 300 epochs and 0.1 for 1600 epochs. +model = dict( + type='BEiT', + backbone=dict( + type='BEiTViT', + arch='base', + patch_size=16, + out_indices=[-4, -1], + drop_path_rate=drop_path_rate, + final_norm=False, + layer_scale_init_value=layer_scale_init_value, + init_cfg=[ + dict(type='TruncNormal', std=0.02, layer='Linear'), + dict(type='TruncNormal', std=0.02, layer='Conv2d'), + dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0) + ]), + neck=dict( + type='BEiTV2Neck', + num_layers=2, + early_layers=9, + backbone_arch='base', + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, + ), + head=dict( + type='BEiTV2Head', + embed_dims=768, + num_embed=8192, + loss=dict(type='BEiTLoss')), + target_generator=dict( + type='VQKD', + encoder_config=vqkd_encoder, + init_cfg=dict( + type='Pretrained', checkpoint='beit_ckpt/vqkd_encoder.pth'))) diff --git a/configs/selfsup/beit/README.md b/configs/selfsup/beit/README.md new file mode 100644 index 00000000..4be656d9 --- /dev/null +++ b/configs/selfsup/beit/README.md @@ -0,0 +1,60 @@ +# BEiT + +> [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) + + + +## Abstract + +We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e, image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into visual tokens. Then we randomly mask some image patches and fed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%). + +
+ +
+ +## Models and Benchmarks + +Here, we report the results of the model on ImageNet, the details are below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
AlgorithmBackboneEpochBatch SizeResults (Top-1 %)Links
Linear EvalFine-tuningPretrainLinear EvalFine-tuning
BEiTViT-base3002048/83.1config | model | log/config | model | log
+ +## Citation + +```bibtex +@inproceedings{bao2022beit, + title={{BE}iT: {BERT} Pre-Training of Image Transformers}, + author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei}, + booktitle={International Conference on Learning Representations}, + year={2022}, +} +``` diff --git a/configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py b/configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py new file mode 100644 index 00000000..8c4c2d79 --- /dev/null +++ b/configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -0,0 +1,56 @@ +_base_ = [ + '../_base_/models/beit_vit-base-p16.py', + '../_base_/datasets/imagenet_beit.py', + '../_base_/schedules/adamw_coslr-300e_in1k.py', + '../_base_/default_runtime.py', +] + +# optimizer wrapper +optimizer = dict( + type='AdamW', lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05) + +optim_wrapper = dict( + type='AmpOptimWrapper', + loss_scale='dynamic', + optimizer=optimizer, + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + eta_min=1e-5, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + logger=dict(type='LoggerHook', interval=100), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3)) + +# randomness +randomness = dict(seed=0, diff_rank_seed=True) + +find_unused_parameters = True diff --git a/configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py b/configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py new file mode 100644 index 00000000..dc73fe40 --- /dev/null +++ b/configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py @@ -0,0 +1,134 @@ +# mmcls:: means we use the default settings from MMClassification +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs64_swin_224.py', + 'mmcls::_base_/schedules/imagenet_bs1024_adamw_swin.py', + 'mmcls::_base_/default_runtime.py' +] + +data_preprocessor = dict( + num_classes=1000, + mean=[127.5, 127.5, 127.5], + std=[127.5, 127.5, 127.5], + to_rgb=True, +) + +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='BEiT', + arch='base', + img_size=224, + patch_size=16, + drop_path_rate=0.1, + avg_token=True, + output_cls_token=False, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False), + neck=None, + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=768, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.02)]), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) + ])) + +file_client_args = dict(backend='disk') +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=0.3333333333333333, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type='PackClsInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='ResizeEdge', + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackClsInputs') +] + +train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict( + type='AdamW', + lr=4e-3, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999), + model_type='vit', + layer_decay_rate=0.65), + constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + _delete_=True, + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + by_epoch=True, + begin=20, + end=100, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) + +train_cfg = dict(by_epoch=True, max_epochs=100) + +randomness = dict(seed=0) diff --git a/configs/selfsup/beit/metafile.yml b/configs/selfsup/beit/metafile.yml new file mode 100644 index 00000000..8e6153c4 --- /dev/null +++ b/configs/selfsup/beit/metafile.yml @@ -0,0 +1,35 @@ +Collections: + - Name: BEiT + Metadata: + Training Data: ImageNet-1k + Training Techniques: + - AdamW + Training Resources: 8x A100-80G GPUs + Architecture: + - ViT + Paper: + URL: https://arxiv.org/abs/2106.08254 + Title: "BEiT: BERT Pre-Training of Image Transformers" + README: configs/selfsup/beit/README.md + +Models: + - Name: beit_vit-base-p16_8xb256-amp-coslr-300e_in1k + In Collection: BEiT + Metadata: + Epochs: 300 + Batch Size: 2048 + Results: null + Config: configs/selfsup/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k.py + Weights: https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221128-ab79e626.pth + Downstream: + - Type: Image Classification + Metadata: + Epochs: 100 + Batch Size: 1024 + Results: + - Task: Fine-tuning + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 83.1 + Config: configs/selfsup/beit/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py + Weights: https://download.openmmlab.com/mmselfsup/1.x/beit/beit_vit-base-p16_8xb256-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221128-0ca393e9.pth diff --git a/configs/selfsup/beitv2/README.md b/configs/selfsup/beitv2/README.md new file mode 100644 index 00000000..0e85fdac --- /dev/null +++ b/configs/selfsup/beitv2/README.md @@ -0,0 +1,60 @@ +# BEiT + +> [BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers](https://arxiv.org/abs/2208.06366) + + + +## Abstract + +Masked image modeling (MIM) has demonstrated impressive results in self-supervised representation learning by recovering corrupted image patches. However, most existing studies operate on low-level image pixels, which hinders the exploitation of high-level semantics for representation models. In this work, we propose to use a semantic-rich visual tokenizer as the reconstruction target for masked prediction, providing a systematic way to promote MIM from pixel-level to semantic-level. Specifically, we propose vector-quantized knowledge distillation to train the tokenizer, which discretizes a continuous semantic space to compact codes. We then pretrain vision Transformers by predicting the original visual tokens for the masked image patches. Furthermore, we introduce a patch aggregation strategy which associates discrete image patches to enhance global semantic representation. Experiments on image classification and semantic segmentation show that BEiT v2 outperforms all compared MIM methods. On ImageNet-1K (224 size), the base-size BEiT v2 achieves 85.5% top-1 accuracy for fine-tuning and 80.1% top-1 accuracy for linear probing. The large-size BEiT v2 obtains 87.3% top-1 accuracy for ImageNet-1K (224 size) fine-tuning, and 56.7% mIoU on ADE20K for semantic segmentation. + +
+ +
+ +## Models and Benchmarks + +Here, we report the results of the model on ImageNet, the details are below: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
AlgorithmBackboneEpochBatch SizeResults (Top-1 %)Links
Linear EvalFine-tuningPretrainLinear EvalFine-tuning
BEiTViT-base3002048/config | model | log/config | model | log
+ +## Citation + +```bibtex +@article{beitv2, + title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers}, + author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei}, + journal={ArXiv}, + year={2022} +} +``` diff --git a/configs/selfsup/beitv2/beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k.py b/configs/selfsup/beitv2/beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k.py new file mode 100644 index 00000000..bdac4004 --- /dev/null +++ b/configs/selfsup/beitv2/beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -0,0 +1,55 @@ +_base_ = [ + '../_base_/models/beitv2_vit-base-p16.py', + '../_base_/datasets/imagenet_beitv2.py', + '../_base_/schedules/adamw_coslr-300e_in1k.py', + '../_base_/default_runtime.py', +] + +# optimizer wrapper +optimizer = dict(type='AdamW', lr=1.5e-3, betas=(0.9, 0.98), weight_decay=0.05) + +optim_wrapper = dict( + type='AmpOptimWrapper', + loss_scale='dynamic', + optimizer=optimizer, + clip_grad=dict(max_norm=3.0), + paramwise_cfg=dict( + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + eta_min=1e-5, + by_epoch=True, + begin=10, + end=300, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + logger=dict(type='LoggerHook', interval=100), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3)) + +# randomness +randomness = dict(seed=0, diff_rank_seed=True) + +find_unused_parameters = True diff --git a/configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py b/configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py new file mode 100644 index 00000000..dd2e33d0 --- /dev/null +++ b/configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py @@ -0,0 +1,129 @@ +# mmcls:: means we use the default settings from MMClassification +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs64_swin_224.py', + 'mmcls::_base_/schedules/imagenet_bs1024_adamw_swin.py', + 'mmcls::_base_/default_runtime.py' +] + +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='BEiT', + arch='base', + img_size=224, + patch_size=16, + # 0.2 for 1600 epochs pretrained models and 0.1 for 300 epochs. + drop_path_rate=0.2, + avg_token=True, + output_cls_token=False, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False), + neck=None, + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=768, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.02)]), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) + ])) + +file_client_args = dict(backend='disk') +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=0.3333333333333333, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type='PackClsInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='ResizeEdge', + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackClsInputs') +] + +train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict( + type='AdamW', + lr=5e-4, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999), + model_type='vit', + # 0.6 for 1600 epochs pretrained models and 0.65 for 300 epochs + layer_decay_rate=0.6), + constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + _delete_=True, + custom_keys={ + # the following configurations are designed for BEiT + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + by_epoch=True, + begin=20, + end=100, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) + +train_cfg = dict(by_epoch=True, max_epochs=100) + +randomness = dict(seed=0) diff --git a/configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-30e_in1k.py b/configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-30e_in1k.py new file mode 100644 index 00000000..08031b64 --- /dev/null +++ b/configs/selfsup/beitv2/classification/vit-base-p16_ft-8xb128-coslr-30e_in1k.py @@ -0,0 +1,126 @@ +# mmcls:: means we use the default settings from MMClassification +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs64_swin_224.py', + 'mmcls::_base_/schedules/imagenet_bs1024_adamw_swin.py', + 'mmcls::_base_/default_runtime.py' +] +# Fine-tuning 30 epoch is for models which have intermediate fine-tuning +# on ImageNet-21k after self-supervised pretrain. + +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='BEiT', + arch='base', + img_size=224, + patch_size=16, + drop_path_rate=0.1, + avg_token=True, + output_cls_token=False, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + use_shared_rel_pos_bias=False), + neck=None, + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=768, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.02)]), +) + +file_client_args = dict(backend='disk') +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies='timm_increasing', + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=0.3333333333333333, + fill_color=[103.53, 116.28, 123.675], + fill_std=[57.375, 57.12, 58.395]), + dict(type='PackClsInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='ResizeEdge', + scale=256, + edge='short', + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='PackClsInputs') +] + +train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict( + type='AdamW', + lr=5e-5, + weight_decay=0.05, + eps=1e-8, + betas=(0.9, 0.999), + model_type='vit', # layer-wise lr decay type + layer_decay_rate=0.75), # layer-wise lr decay factor + constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + _delete_=True, + custom_keys={ + # the following configurations are designed for BEiTs + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + 'q_bias': dict(decay_mult=0.0), + 'v_bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + '.gamma': dict(decay_mult=0.0), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + by_epoch=True, + begin=20, + end=30, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) + +train_cfg = dict(by_epoch=True, max_epochs=30) + +randomness = dict(seed=0) diff --git a/configs/selfsup/beitv2/metafile.yml b/configs/selfsup/beitv2/metafile.yml new file mode 100644 index 00000000..826ce32c --- /dev/null +++ b/configs/selfsup/beitv2/metafile.yml @@ -0,0 +1,35 @@ +Collections: + - Name: BEiTv2 + Metadata: + Training Data: ImageNet-1k + Training Techniques: + - AdamW + Training Resources: 8x A100-80G GPUs + Architecture: + - ViT + Paper: + URL: https://arxiv.org/abs/2208.06366 + Title: 'BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers' + README: configs/selfsup/beitv2/README.md + +Models: + - Name: beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k + In Collection: BEiTv2 + Metadata: + Epochs: 300 + Batch Size: 2048 + Results: null + Config: configs/selfsup/beitv2/beitv2_vit-base-p16_8xb256-amp-coslr-300e_in1k.py + Weights: + Downstream: + - Type: Image Classification + Metadata: + Epochs: 100 + Batch Size: 1024 + Results: + - Task: Fine-tuning + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: + Config: + Weights: diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index f1e47e68..75bb9ee6 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -393,5 +393,16 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. T / config | model | log + + BEiT + ViT-base + 300 + 2048 + / + 83.1 + config | model | log + / + config | model | log + diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md index bb7b7d36..091ee064 100644 --- a/docs/zh_cn/model_zoo.md +++ b/docs/zh_cn/model_zoo.md @@ -393,5 +393,16 @@ ImageNet 有多个版本,不过最常用的是 ILSVRC 2012。我们提供了 / config | model | log + + BEiT + ViT-base + 300 + 2048 + / + 83.1 + config | model | log + / + config | model | log + diff --git a/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py b/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py index 18e2c07c..e262c0b8 100644 --- a/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py +++ b/mmselfsup/engine/optimizers/layer_decay_optim_wrapper_constructor.py @@ -66,11 +66,8 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): Note: Currently, this optimizer constructor is built for ViT and Swin. - In addition to applying layer-wise learning rate decay schedule, this - module will not apply weight decay to ``normalization parameters``, - ``bias``, ``position embedding``, ``class token``, and - ``relative position bias table, automatically. What's more, the - ``paramwise_cfg`` in the base module will be ignored. + In addition to applying layer-wise learning rate decay schedule, the + paramwise_cfg only supports weight decay customization. """ def add_params(self, params: List[dict], module: nn.Module, @@ -87,14 +84,27 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): optimizer_cfg (dict): The configuration of optimizer. prefix (str): The prefix of the module. """ + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + # get logger logger = MMLogger.get_current_instance() + logger.warning( + 'LearningRateDecayOptimWrapperConstructor is refactored in ' + 'v1.0.0rc4, which need to configure zero weight decay manually. ' + 'The previous versions would set zero weight decay according to ' + 'the dimension of parameter. Please specify weight decay settings ' + 'of different layers in config if needed.') # Check if self.param_cfg is not None if len(self.paramwise_cfg) > 0: - logger.info('The paramwise_cfg will be ignored, and normalization \ - parameters, bias, position embedding, class token and \ - relative position bias table will not be decayed by \ - default.') + logger.info( + 'The paramwise_cfg only supports weight decay customization ' + 'in LearningRateDecayOptimWrapperConstructor, please indicate ' + 'the specific weight decay settings of different layers in ' + 'config if needed.') model_type = optimizer_cfg.pop('model_type', None) # model_type should not be None @@ -111,24 +121,25 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor): elif model_type == 'swin': num_layers = sum(module.backbone.depths) + 2 - weight_decay = self.base_wd # if layer_decay_rate is not provided, not decay decay_rate = optimizer_cfg.pop('layer_decay_rate', 1.0) parameter_groups = {} + assert self.base_wd is not None for name, param in module.named_parameters(): if not param.requires_grad: continue # frozen weights - # will not decay normalization params, bias, position embedding - # class token, relative position bias table - if len(param.shape) == 1 or name.endswith('.bias') or name in ( - 'backbone.pos_embed', 'backbone.cls_token' - ) or 'relative_position_bias_table' in name: + + this_weight_decay = self.base_wd + for key in sorted_keys: + if key in name: + decay_mult = custom_keys[key].get('decay_mult', 1.) + this_weight_decay = self.base_wd * decay_mult + + if this_weight_decay == 0: group_name = 'no_decay' - this_weight_decay = 0. else: group_name = 'decay' - this_weight_decay = weight_decay if model_type == 'vit': layer_id = get_layer_id_for_vit(name, num_layers) diff --git a/mmselfsup/models/algorithms/__init__.py b/mmselfsup/models/algorithms/__init__.py index 7f924a1b..2a13fb5e 100644 --- a/mmselfsup/models/algorithms/__init__.py +++ b/mmselfsup/models/algorithms/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .barlowtwins import BarlowTwins from .base import BaseModel +from .beit import BEiT from .byol import BYOL from .cae import CAE from .deepcluster import DeepCluster @@ -19,7 +20,7 @@ from .simsiam import SimSiam from .swav import SwAV __all__ = [ - 'BaseModel', 'BarlowTwins', 'BYOL', 'DeepCluster', 'DenseCL', 'MoCo', - 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam', 'SwAV', - 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat' + 'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL', + 'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam', + 'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat' ] diff --git a/mmselfsup/models/algorithms/beit.py b/mmselfsup/models/algorithms/beit.py new file mode 100644 index 00000000..44df5847 --- /dev/null +++ b/mmselfsup/models/algorithms/beit.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch + +from mmselfsup.registry import MODELS +from mmselfsup.structures import SelfSupDataSample +from .base import BaseModel + + +@MODELS.register_module() +class BEiT(BaseModel): + """BEiT v1/v2. + + Implementation of `BEiT: BERT Pre-Training of Image Transformers + `_. Implementation of `BEiT v2: Masked + Image Modeling with Vector-Quantized Visual Tokenizers + `_. + """ + + def loss(self, batch_inputs: List[torch.Tensor], + data_samples: List[SelfSupDataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + batch_inputs (List[torch.Tensor]): The input images. + data_samples (List[SelfSupDataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + mask = torch.stack( + [data_sample.mask.value for data_sample in data_samples]) + + img_latent = self.backbone(batch_inputs[0], mask) + + # batch_inputs[1] is the target image + with torch.no_grad(): + target = self.target_generator(batch_inputs[1]) + target = target.detach() + + if self.with_neck: + # BEiT v2 + feats, feats_cls_pt = self.neck( + img_latent, rel_pos_bias=self.backbone.shared_rel_pos_bias) + loss = self.head(feats, feats_cls_pt, target, mask) + else: + # BEiT v1 + loss = self.head(img_latent[0], target, mask) + + if isinstance(loss, torch.Tensor): + losses = dict(loss=loss) + return losses + elif isinstance(loss, Tuple): + # the loss_1 and loss_2 are general reconstruction loss (patch + # feature vectors from last layer of backbone) and early state + # reconstruction loss (patch feature vectors from intermediate + # layer of backbone) + loss_1, loss_2 = loss[0], loss[1] + losses = dict() + # the key with prefix 'loss', like loss_1 and loss_2, will be used + # as the final criterion + losses['loss_1'] = loss_1 + losses['loss_2'] = loss_2 + return losses diff --git a/mmselfsup/models/backbones/__init__.py b/mmselfsup/models/backbones/__init__.py index 17559f0f..f64b5f64 100644 --- a/mmselfsup/models/backbones/__init__.py +++ b/mmselfsup/models/backbones/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .beit_vit import BEiTViT from .cae_vit import CAEViT from .mae_vit import MAEViT from .maskfeat_vit import MaskFeatViT @@ -9,5 +10,5 @@ from .simmim_swin import SimMIMSwinTransformer __all__ = [ 'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT', - 'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT' + 'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT' ] diff --git a/mmselfsup/models/backbones/beit_vit.py b/mmselfsup/models/backbones/beit_vit.py new file mode 100644 index 00000000..d8755df4 --- /dev/null +++ b/mmselfsup/models/backbones/beit_vit.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Tuple, Union + +import torch +from mmcls.models import BEiT, resize_pos_embed +from mmengine.model.weight_init import trunc_normal_ +from torch import nn + +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class BEiTViT(BEiT): + """Vision Transformer for BEiT pre-training. + + Rewritten version of: `An Image is Worth 16x16 Words: Transformers + for Image Recognition at Scale `_ + + Args: + arch (str | dict): Vision Transformer architecture. If use string, + choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small' + and 'deit-base'. If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **num_layers** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + - **feedforward_channels** (int): The hidden dimensions in + feedforward modules. + + Defaults to 'base'. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to the most + common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + with_cls_token (bool): Whether concatenating class token into image + tokens as transformer input. Defaults to True. + avg_token (bool): Whether or not to use the mean patch token for + classification. If True, the model will only take the average + of all patch tokens. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + output_cls_token (bool): Whether output the cls_token. If set True, + ``with_cls_token`` must be True. Defaults to True. + use_abs_pos_emb (bool): Whether or not use absolute position embedding. + Defaults to False. + use_rel_pos_bias (bool): Whether or not use relative position bias. + Defaults to False. + use_shared_rel_pos_bias (bool): Whether or not use shared relative + position bias. Defaults to True. + layer_scale_init_value (float): The initialization value for + the learnable scaling of attention and FFN. Defaults to 0.1. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. Defaults to an empty dict. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + arch: str = 'base', + img_size: int = 224, + patch_size: int = 16, + in_channels: int = 3, + out_indices: int = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + avg_token: bool = False, + frozen_stages: int = -1, + output_cls_token: bool = True, + use_abs_pos_emb: bool = False, + use_rel_pos_bias: bool = False, + use_shared_rel_pos_bias: bool = True, + layer_scale_init_value: int = 0.1, + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(padding=0), + layer_cfgs: dict = dict(), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + avg_token=avg_token, + frozen_stages=frozen_stages, + output_cls_token=output_cls_token, + use_abs_pos_emb=use_abs_pos_emb, + use_shared_rel_pos_bias=use_shared_rel_pos_bias, + use_rel_pos_bias=use_rel_pos_bias, + layer_scale_init_value=layer_scale_init_value, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + init_cfg=init_cfg) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims)) + + def init_weights(self) -> None: + """Initialize position embedding, patch embedding and cls token.""" + super().init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + trunc_normal_(self.cls_token, std=0.02) + trunc_normal_(self.mask_token, std=0.02) + self.rescale_init_weight() + + def rescale_init_weight(self) -> None: + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.layers): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, layer_id + 1) + + def forward(self, x: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor]: + """The BEiT style forward function. + + Args: + x (torch.Tensor): Input images, which is of shape (B x C x H x W). + mask (torch.Tensor): Mask for input, which is of shape + (B x patch_resolution[0] x patch_resolution[1]). + + Returns: + Tuple[torch.Tensor]: Hidden features. + """ + x, patch_resolution = self.patch_embed(x) + + # replace the masked visual tokens by mask_token + B, L, _ = x.shape + mask_token = self.mask_token.expand(B, L, -1) + w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) + x = x * (1. - w) + mask_token * w + + # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + resize_pos_embed( + self.pos_embed, + self.patch_resolution, + patch_resolution, + mode=self.interpolate_mode, + num_extra_tokens=self.num_extra_tokens) + x = self.drop_after_pos(x) + + self.shared_rel_pos_bias = self.rel_pos_bias().to( + mask.device) if self.rel_pos_bias is not None else None + + outs = [] + for i, layer in enumerate(self.layers): + x = layer(x, rel_pos_bias=self.shared_rel_pos_bias) + + if i == len(self.layers) - 1 and self.final_norm: + x = self.norm1(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/mmselfsup/models/heads/__init__.py b/mmselfsup/models/heads/__init__.py index 8f0026f7..86d82fb1 100644 --- a/mmselfsup/models/heads/__init__.py +++ b/mmselfsup/models/heads/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .beitv1_head import BEiTV1Head +from .beitv2_head import BEiTV2Head from .cae_head import CAEHead from .cls_head import ClsHead from .contrastive_head import ContrastiveHead @@ -11,7 +13,8 @@ from .simmim_head import SimMIMHead from .swav_head import SwAVHead __all__ = [ - 'ContrastiveHead', 'ClsHead', 'LatentPredictHead', - 'LatentCrossCorrelationHead', 'MultiClsHead', 'MAEPretrainHead', - 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead', 'MaskFeatPretrainHead' + 'BEiTV1Head', 'BEiTV2Head', 'ContrastiveHead', 'ClsHead', + 'LatentPredictHead', 'LatentCrossCorrelationHead', 'MultiClsHead', + 'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead', + 'MaskFeatPretrainHead' ] diff --git a/mmselfsup/models/heads/beitv1_head.py b/mmselfsup/models/heads/beitv1_head.py new file mode 100644 index 00000000..0eebc597 --- /dev/null +++ b/mmselfsup/models/heads/beitv1_head.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class BEiTV1Head(BaseModule): + """Pretrain Head for BEiT v1. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss = MODELS.build(loss) + + def forward(self, feats: torch.Tensor, target: torch.Tensor, + mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. + """ + mask = mask.flatten(1).to(torch.bool) + target = torch.argmax(target, dim=1).flatten(1) + target = target[mask] + + # remove cls_token + feats = feats[:, 1:] + logits = self.cls_head(feats[mask]) + + loss = self.loss(logits, target) + return loss diff --git a/mmselfsup/models/heads/beitv2_head.py b/mmselfsup/models/heads/beitv2_head.py new file mode 100644 index 00000000..d77cf63a --- /dev/null +++ b/mmselfsup/models/heads/beitv2_head.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class BEiTV2Head(BaseModule): + """Pretrain Head for BEiT. + + Compute the logits and the cross entropy loss. + + Args: + embed_dims (int): The dimension of embedding. + num_embed (int): The number of classification types. + loss (dict): The config of loss. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + embed_dims: int, + num_embed: int, + loss: dict, + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + self.cls_head = nn.Linear(embed_dims, num_embed) + self.loss = MODELS.build(loss) + + def forward(self, feats: torch.Tensor, feats_cls_pt: torch.Tensor, + target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate loss. + + Args: + feats (torch.Tensor): Features from backbone. + feats_cls_pt (torch.Tensor) : Features from class late layers for + pretraining. + target (torch.Tensor): Target generated by target_generator. + mask (torch.Tensor): Generated mask for pretraing. + """ + mask = mask.flatten(1).to(torch.bool) + target = target[mask] + + # shared cls head + logits = self.cls_head(feats[mask]) + logits_cls_pt = self.cls_head(feats_cls_pt[mask]) + + loss = self.loss((logits, logits_cls_pt), target) + return loss diff --git a/mmselfsup/models/losses/__init__.py b/mmselfsup/models/losses/__init__.py index 752a50db..f0754716 100644 --- a/mmselfsup/models/losses/__init__.py +++ b/mmselfsup/models/losses/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .beit_loss import BEiTLoss from .cae_loss import CAELoss from .cosine_similarity_loss import CosineSimilarityLoss from .cross_correlation_loss import CrossCorrelationLoss @@ -8,7 +9,7 @@ from .simmim_loss import SimMIMReconstructionLoss from .swav_loss import SwAVLoss __all__ = [ - 'CAELoss', 'CrossCorrelationLoss', 'CosineSimilarityLoss', + 'BEiTLoss', 'CAELoss', 'CrossCorrelationLoss', 'CosineSimilarityLoss', 'MAEReconstructionLoss', 'SimMIMReconstructionLoss', 'SwAVLoss', 'PixelReconstructionLoss' ] diff --git a/mmselfsup/models/losses/beit_loss.py b/mmselfsup/models/losses/beit_loss.py new file mode 100644 index 00000000..84178929 --- /dev/null +++ b/mmselfsup/models/losses/beit_loss.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, Union + +import torch +from mmengine.model import BaseModule +from torch import nn + +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class BEiTLoss(BaseModule): + """Loss function for BEiT. + + The BEiTLoss supports 2 diffenrent logits shared 1 target, like BEiT v2. + """ + + def __init__(self) -> None: + super().__init__() + self.loss_cross_entropy = nn.CrossEntropyLoss() + + def forward(self, logits: Union[Tuple[torch.Tensor], torch.Tensor], + target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function of BEiT Loss. + + Args: + logits (torch.Tensor): The outputs from the decoder. + target (torch.Tensor): The targets generated by dalle. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The main loss. + """ + if isinstance(logits, torch.Tensor): + loss = self.loss_cross_entropy(logits, target) + return loss + elif isinstance(logits, Tuple): + loss_1 = self.loss_cross_entropy(logits[0], target) + loss_2 = self.loss_cross_entropy(logits[1], target) + return loss_1, loss_2 diff --git a/mmselfsup/models/necks/__init__.py b/mmselfsup/models/necks/__init__.py index 19cec7d3..ab5faad5 100644 --- a/mmselfsup/models/necks/__init__.py +++ b/mmselfsup/models/necks/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .avgpool2d_neck import AvgPool2dNeck +from .beitv2_neck import BEiTV2Neck from .cae_neck import CAENeck from .densecl_neck import DenseCLNeck from .linear_neck import LinearNeck @@ -12,7 +13,7 @@ from .simmim_neck import SimMIMNeck from .swav_neck import SwAVNeck __all__ = [ - 'AvgPool2dNeck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck', + 'AvgPool2dNeck', 'BEiTV2Neck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck', 'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck', 'MAEPretrainDecoder', 'SimMIMNeck', 'CAENeck', 'ClsBatchNormNeck' ] diff --git a/mmselfsup/models/necks/beitv2_neck.py b/mmselfsup/models/necks/beitv2_neck.py new file mode 100644 index 00000000..4f65f050 --- /dev/null +++ b/mmselfsup/models/necks/beitv2_neck.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from mmcls.models.backbones.beit import BEiTTransformerEncoderLayer +from mmcv.cnn import build_norm_layer +from mmengine.model import BaseModule + +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class BEiTV2Neck(BaseModule): + """Neck for BEiTV2 Pre-training. + + This module construct the decoder for the final prediction. + + Args: + num_layers (int): Number of encoder layers of neck. Defaults to 2. + early_layers (int): The layer index of the early output from the + backbone. Defaults to 9. + backbone_arch (str): Vision Transformer architecture. Defaults to base. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): The initialization value for the + learnable scaling of attention and FFN. Defaults to 0.1. + use_rel_pos_bias (bool): Whether to use unique relative position bias, + if False, use shared relative position bias defined in backbone. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 768, + 'depth': 12, + 'num_heads': 12, + 'feedforward_channels': 3072, + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 1024, + 'depth': 24, + 'num_heads': 16, + 'feedforward_channels': 4096, + }), + } + + def __init__( + self, + num_layers: int = 2, + early_layers: int = 9, + backbone_arch: str = 'base', + drop_rate: float = 0., + drop_path_rate: float = 0., + layer_scale_init_value: float = 0.1, + use_rel_pos_bias: bool = False, + norm_cfg: dict = dict(type='LN', eps=1e-6), + init_cfg: Optional[Union[dict, List[dict]]] = dict( + type='TruncNormal', layer='Linear', std=0.02, bias=0) + ) -> None: + super().__init__(init_cfg=init_cfg) + + if isinstance(backbone_arch, str): + backbone_arch = backbone_arch.lower() + assert backbone_arch in set(self.arch_zoo), \ + (f'Arch {backbone_arch} is not in default archs ' + f'{set(self.arch_zoo)}') + self.arch_settings = self.arch_zoo[backbone_arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(backbone_arch, dict) and essential_keys <= set( + backbone_arch + ), f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = backbone_arch + + # stochastic depth decay rule + self.early_layers = early_layers + depth = self.arch_settings['depth'] + dpr = np.linspace(0, drop_path_rate, + max(depth, early_layers + num_layers)) + + self.patch_aggregation = nn.ModuleList() + for i in range(early_layers, early_layers + num_layers): + _layer_cfg = dict( + embed_dims=self.arch_settings['embed_dims'], + num_heads=self.arch_settings['num_heads'], + feedforward_channels=self. + arch_settings['feedforward_channels'], + drop_rate=drop_rate, + drop_path_rate=dpr[i], + norm_cfg=norm_cfg, + layer_scale_init_value=layer_scale_init_value, + window_size=None, + use_rel_pos_bias=use_rel_pos_bias) + self.patch_aggregation.append( + BEiTTransformerEncoderLayer(**_layer_cfg)) + + self.rescale_patch_aggregation_init_weight() + + embed_dims = self.arch_settings['embed_dims'] + _, norm = build_norm_layer(norm_cfg, embed_dims) + self.add_module('norm', norm) + + def rescale_patch_aggregation_init_weight(self): + """Rescale the initialized weights.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.patch_aggregation): + rescale(layer.attn.proj.weight.data, + self.early_layers + layer_id + 1) + rescale(layer.ffn.layers[1].weight.data, + self.early_layers + layer_id + 1) + + def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor, + **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the latent prediction and final prediction. + + Args: + x (Tuple[torch.Tensor]): Features of tokens. + rel_pos_bias (torch.Tensor): Shared relative position bias table. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - ``x``: The final layer features from backbone, which are normed + in ``BEiTV2Neck``. + - ``x_cls_pt``: The early state features from backbone, which are + consist of final layer cls_token and early state patch_tokens + from backbone and sent to PatchAggregation layers in the neck. + """ + + early_states, x = inputs[0], inputs[1] + x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1) + for layer in self.patch_aggregation: + x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias) + + # shared norm + x, x_cls_pt = self.norm(x), self.norm(x_cls_pt) + + # remove cls_token + x = x[:, 1:] + x_cls_pt = x_cls_pt[:, 1:] + return x, x_cls_pt diff --git a/mmselfsup/models/target_generators/__init__.py b/mmselfsup/models/target_generators/__init__.py index 4e61ce15..1876d59e 100644 --- a/mmselfsup/models/target_generators/__init__.py +++ b/mmselfsup/models/target_generators/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .dall_e import Encoder from .hog_generator import HOGGenerator +from .vqkd import VQKD -__all__ = [ - 'HOGGenerator', -] +__all__ = ['HOGGenerator', 'VQKD', 'Encoder'] diff --git a/mmselfsup/models/target_generators/dall_e.py b/mmselfsup/models/target_generators/dall_e.py new file mode 100644 index 00000000..631617f1 --- /dev/null +++ b/mmselfsup/models/target_generators/dall_e.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Modified from BEiT +# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py +import math +from collections import OrderedDict +from functools import partial +from typing import List, Optional, Union + +import attr +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmselfsup.registry import MODELS + + +@attr.s(eq=False) +class Conv2d(nn.Module): + n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) + n_out: int = attr.ib(validator=lambda i, a, x: x >= 1) + kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1) + + use_float16: bool = attr.ib(default=True) + device: torch.device = attr.ib(default=torch.device('cpu')) + requires_grad: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: + super().__init__() + + w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), + dtype=torch.float32, + device=self.device, + requires_grad=self.requires_grad) + w.normal_(std=1 / math.sqrt(self.n_in * self.kw**2)) + + b = torch.zeros((self.n_out, ), + dtype=torch.float32, + device=self.device, + requires_grad=self.requires_grad) + self.w, self.b = nn.Parameter(w), nn.Parameter(b) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_float16 and 'cuda' in self.w.device.type: + if x.dtype != torch.float16: + x = x.half() + + w, b = self.w.half(), self.b.half() + else: + if x.dtype != torch.float32: + x = x.float() + + w, b = self.w, self.b + + return F.conv2d(x, w, b, padding=(self.kw - 1) // 2) + + +@attr.s(eq=False, repr=False) +class EncoderBlock(nn.Module): + n_in: int = attr.ib(validator=lambda i, a, x: x >= 1) + n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0) + n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1) + + device: torch.device = attr.ib(default=None) + requires_grad: bool = attr.ib(default=False) + + def __attrs_post_init__(self) -> None: + super().__init__() + self.n_hid = self.n_out // 4 + self.post_gain = 1 / (self.n_layers**2) + + make_conv = partial( + Conv2d, device=self.device, requires_grad=self.requires_grad) + self.id_path = make_conv( + self.n_in, self.n_out, + 1) if self.n_in != self.n_out else nn.Identity() + self.res_path = nn.Sequential( + OrderedDict([ + ('relu_1', nn.ReLU()), + ('conv_1', make_conv(self.n_in, self.n_hid, 3)), + ('relu_2', nn.ReLU()), + ('conv_2', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_3', nn.ReLU()), + ('conv_3', make_conv(self.n_hid, self.n_hid, 3)), + ('relu_4', nn.ReLU()), + ('conv_4', make_conv(self.n_hid, self.n_out, 1)), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.id_path(x) + self.post_gain * self.res_path(x) + + +@attr.s(eq=False, repr=False) +@MODELS.register_module(name='DALL-E') +class Encoder(BaseModule): + group_count: int = 4 + n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64) + n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1) + input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1) + vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512) + + device: torch.device = attr.ib(default=torch.device('cpu')) + requires_grad: bool = attr.ib(default=False) + use_mixed_precision: bool = attr.ib(default=True) + init_cfg: Optional[Union[dict, List[dict]]] = attr.ib(default=None) + + def __attrs_post_init__(self) -> None: + super().__init__(init_cfg=self.init_cfg) + + blk_range = range(self.n_blk_per_group) + n_layers = self.group_count * self.n_blk_per_group + make_conv = partial( + Conv2d, device=self.device, requires_grad=self.requires_grad) + make_blk = partial( + EncoderBlock, + n_layers=n_layers, + device=self.device, + requires_grad=self.requires_grad) + + self.blocks = nn.Sequential( + OrderedDict([ + ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)), + ('group_1', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk(1 * self.n_hid, 1 * self.n_hid)) + for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_2', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk( + 1 * self.n_hid if i == 0 else 2 * self.n_hid, + 2 * self.n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_3', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk( + 2 * self.n_hid if i == 0 else 4 * self.n_hid, + 4 * self.n_hid)) for i in blk_range], + ('pool', nn.MaxPool2d(kernel_size=2)), + ]))), + ('group_4', + nn.Sequential( + OrderedDict([ + *[(f'block_{i + 1}', + make_blk( + 4 * self.n_hid if i == 0 else 8 * self.n_hid, + 8 * self.n_hid)) for i in blk_range], + ]))), + ('output', + nn.Sequential( + OrderedDict([ + ('relu', nn.ReLU()), + ('conv', + make_conv( + 8 * self.n_hid, + self.vocab_size, + 1, + use_float16=False)), + ]))), + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.float() + if len(x.shape) != 4: + raise ValueError(f'input shape {x.shape} is not 4d') + if x.shape[1] != self.input_channels: + raise ValueError(f'input has {x.shape[1]} channels but model \ + built for {self.input_channels}') + if x.dtype != torch.float32: + raise ValueError('input must have dtype torch.float32') + + return self.blocks(x) diff --git a/mmselfsup/models/target_generators/vqkd.py b/mmselfsup/models/target_generators/vqkd.py new file mode 100644 index 00000000..c8c77af3 --- /dev/null +++ b/mmselfsup/models/target_generators/vqkd.py @@ -0,0 +1,104 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional, Tuple + +import torch +from einops import rearrange +from mmcls.models import BEiT +from mmengine.model import BaseModule +from torch import nn + +from mmselfsup.models.utils import NormEMAVectorQuantizer +from mmselfsup.registry import MODELS + + +@MODELS.register_module() +class VQKD(BaseModule): + """Vector-Quantized Knowledge Distillation. + + The module only contains encoder and VectorQuantizer part + Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py + + Args: + encoder_config (dict): The config of encoder. + decoder_config (dict, optional): The config of decoder. Currently, + VQKD only support to build encoder. Defaults to None. + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + decay (float): The decay parameter of EMA. Defaults to 0.99. + beta (float): The mutiplier for VectorQuantizer loss. Defaults to 1. + quantize_kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + init_cfg (dict or List[dict], optional): Initialization config dict. + Defaults to None. + """ # noqa: E501 + + def __init__(self, + encoder_config: dict, + decoder_config: Optional[dict] = None, + num_embed: int = 8192, + embed_dims: int = 32, + decay: float = 0.99, + beta: float = 1.0, + quantize_kmeans_init: bool = True, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.encoder = BEiT(**encoder_config) + if decoder_config is not None: + self.decoder = BEiT(**decoder_config) + + self.quantize = NormEMAVectorQuantizer( + num_embed=num_embed, + embed_dims=embed_dims, + beta=beta, + decay=decay, + kmeans_init=quantize_kmeans_init, + ) + + # task layer + self.encode_task_layer = nn.Sequential( + nn.Linear(self.encoder.arch_settings['embed_dims'], + self.encoder.arch_settings['embed_dims']), nn.Tanh(), + nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims)) + + def get_tokens(self, x: torch.Tensor) -> dict: + """Get tokens for beit pre-training.""" + _, embed_ind, _ = self.encode(x) + output = {} + output['token'] = embed_ind.view(x.shape[0], -1) + output['input_img'] = x + + return output + + def encode( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode the input images and get corresponding results.""" + encoder_features = self.encoder(x)[0] + B, C, N1, N2 = encoder_features.shape + encoder_features = encoder_features.permute(0, 2, 3, + 1).reshape(B, N1 * N2, C) + + with torch.cuda.amp.autocast(enabled=False): + to_quantizer_features = self.encode_task_layer( + encoder_features.type_as(self.encode_task_layer[-1].weight)) + + N = to_quantizer_features.shape[1] + h, w = int(math.sqrt(N)), int(math.sqrt(N)) + + to_quantizer_features = rearrange( + to_quantizer_features, 'b (h w) c -> b c h w', h=h, + w=w) # reshape for quantizer + quantize, loss, embed_ind = self.quantize(to_quantizer_features) + + return quantize, embed_ind, loss + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """The forward function. + + Currently, only support to get tokens. + """ + return self.get_tokens(x)['token'] diff --git a/mmselfsup/models/utils/__init__.py b/mmselfsup/models/utils/__init__.py index 284251da..0deff280 100644 --- a/mmselfsup/models/utils/__init__.py +++ b/mmselfsup/models/utils/__init__.py @@ -3,7 +3,8 @@ from .dall_e import Encoder from .data_preprocessor import (CAEDataPreprocessor, RelativeLocDataPreprocessor, RotationPredDataPreprocessor, - SelfSupDataPreprocessor) + SelfSupDataPreprocessor, + TwoNormDataPreprocessor) from .ema import CosineEMA from .extractor import Extractor from .gather_layer import GatherLayer @@ -13,6 +14,7 @@ from .position_embedding import build_2d_sincos_position_embedding from .sobel import Sobel from .transformer_blocks import (CAETransformerRegressorLayer, MultiheadAttention, TransformerEncoderLayer) +from .vector_quantizer import NormEMAVectorQuantizer try: from .res_layer_extra_norm import ResLayerExtraNorm @@ -20,9 +22,22 @@ except ImportError: ResLayerExtraNorm = None __all__ = [ - 'Extractor', 'GatherLayer', 'MultiPooling', 'MultiPrototypes', - 'build_2d_sincos_position_embedding', 'Sobel', 'MultiheadAttention', - 'TransformerEncoderLayer', 'CAETransformerRegressorLayer', 'Encoder', - 'CosineEMA', 'SelfSupDataPreprocessor', 'RelativeLocDataPreprocessor', - 'RotationPredDataPreprocessor', 'CAEDataPreprocessor', 'ResLayerExtraNorm' + 'Extractor', + 'GatherLayer', + 'MultiPooling', + 'MultiPrototypes', + 'build_2d_sincos_position_embedding', + 'Sobel', + 'MultiheadAttention', + 'TransformerEncoderLayer', + 'CAETransformerRegressorLayer', + 'Encoder', + 'CosineEMA', + 'SelfSupDataPreprocessor', + 'RelativeLocDataPreprocessor', + 'RotationPredDataPreprocessor', + 'CAEDataPreprocessor', + 'ResLayerExtraNorm', + 'NormEMAVectorQuantizer', + 'TwoNormDataPreprocessor', ] diff --git a/mmselfsup/models/utils/data_preprocessor.py b/mmselfsup/models/utils/data_preprocessor.py index c2786c72..4901cd3d 100644 --- a/mmselfsup/models/utils/data_preprocessor.py +++ b/mmselfsup/models/utils/data_preprocessor.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Union import torch from mmengine.model import ImgDataPreprocessor @@ -179,3 +179,114 @@ class CAEDataPreprocessor(SelfSupDataPreprocessor): batch_inputs[1] / 255. * 0.8 + 0.1] return batch_inputs, batch_data_samples + + +@MODELS.register_module() +class TwoNormDataPreprocessor(SelfSupDataPreprocessor): + """Image pre-processor for CAE, BEiT v1/v2, etc. + + Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module + will normalize the prediction image and target image with different + normalization parameters. + + Args: + mean (Sequence[float or int], optional): The pixel mean of image + channels. If ``bgr_to_rgb=True`` it means the mean value of R, + G, B channels. If the length of `mean` is 1, it means all + channels have the same mean value, or the input is a gray image. + If it is not specified, images will not be normalized. Defaults + None. + std (Sequence[float or int], optional): The pixel standard deviation of + image channels. If ``bgr_to_rgb=True`` it means the standard + deviation of R, G, B channels. If the length of `std` is 1, + it means all channels have the same standard deviation, or the + input is a gray image. If it is not specified, images will + not be normalized. Defaults None. + second_mean (Sequence[float or int], optional): The description is + like ``mean``, it can be customized for targe image. Defaults None. + second_std (Sequence[float or int], optional): The description is + like ``std``, it can be customized for targe image. Defaults None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + rgb_to_bgr (bool): whether to convert image from RGB to RGB. + Defaults to False. + non_blocking (bool): Whether block current process + when transferring data to device. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + second_mean: Sequence[Union[float, int]] = None, + second_std: Sequence[Union[float, int]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + non_blocking=non_blocking) + assert (second_mean is not None) and (second_std is not None), ( + 'mean and std should not be None while using ' + '`TwoNormDataPreprocessor`') + assert len(second_mean) == 3 or len(second_mean) == 1, ( + '`mean` should have 1 or 3 values, to be compatible with ' + f'RGB or gray image, but got {len(second_mean)} values') + assert len(second_std) == 3 or len(second_std) == 1, ( + '`std` should have 1 or 3 values, to be compatible with RGB ' # type: ignore # noqa: E501 + f'or gray image, but got {len(std)} values') # type: ignore + + self.register_buffer('second_mean', + torch.tensor(second_mean).view(-1, 1, 1), False) + self.register_buffer('second_std', + torch.tensor(second_std).view(-1, 1, 1), False) + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # Convert to float after channel conversion to ensure + # efficiency + batch_inputs = [input_.float() for input_ in batch_inputs] + + # Normalization. Here is what is different from + # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target + # image and prediction image with different normalization params + if self._enable_normalize: + batch_inputs = [ + (batch_inputs[0] - self.mean) / self.std, + (batch_inputs[1] - self.second_mean) / self.second_std + ] + + return batch_inputs, batch_data_samples diff --git a/mmselfsup/models/utils/vector_quantizer.py b/mmselfsup/models/utils/vector_quantizer.py new file mode 100644 index 00000000..7c2ea893 --- /dev/null +++ b/mmselfsup/models/utils/vector_quantizer.py @@ -0,0 +1,232 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) 2022 Microsoft +# Modified from +# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from mmengine.dist import all_reduce + + +def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, + decay: torch.Tensor) -> None: + """Update moving average with norm data.""" + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1)) + + +def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor: + """Sample vectors according to the given number.""" + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num, ), device=device) + + return samples[indices] + + +def kmeans(samples: torch.Tensor, + num_clusters: int, + num_iters: int = 10, + use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """Run k-means algorithm.""" + dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = F.normalize(new_means, p=2, dim=-1) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + """The codebook of embedding vectors. + + Args: + num_tokens (int): Number of embedding vectors in the codebook. + codebook_dim (int) : The dimension of embedding vectors in the + codebook. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. + """ + + def __init__(self, + num_tokens: int, + codebook_dim: int, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + if codebook_init_path is None: + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = F.normalize(weight, p=2, dim=-1) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f'load init codebook weight from {codebook_init_path}') + codebook_ckpt_weight = torch.load( + codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data: torch.Tensor) -> None: + """Initialize embedding vectors of codebook.""" + if self.initted: + return + print('Performing K-means init for codebook') + embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id: torch.Tensor) -> torch.Tensor: + """Get embedding vectors.""" + return F.embedding(embed_id, self.weight) + + +class NormEMAVectorQuantizer(nn.Module): + """Normed EMA vector quantizer module. + + Args: + num_embed (int): Number of embedding vectors in the codebook. Defaults + to 8192. + embed_dims (int) : The dimension of embedding vectors in the codebook. + Defaults to 32. + beta (float): The mutiplier for VectorQuantizer embedding loss. + Defaults to 1. + decay (float): The decay parameter of EMA. Defaults to 0.99. + statistic_code_usage (bool): Whether to use cluster_size to record + statistic. Defaults to True. + kmeans_init (bool): Whether to use k-means to initialize the + VectorQuantizer. Defaults to True. + codebook_init_path (str): The initialization checkpoint for codebook. + Defaults to None. + """ + + def __init__(self, + num_embed: int, + embed_dims: int, + beta: float, + decay: float = 0.99, + statistic_code_usage: bool = True, + kmeans_init: bool = True, + codebook_init_path: Optional[str] = None) -> None: + super().__init__() + self.codebook_dim = embed_dims + self.num_tokens = num_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA( + num_tokens=self.num_tokens, + codebook_dim=self.codebook_dim, + kmeans_init=kmeans_init, + codebook_init_path=codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(num_embed)) + + def reset_cluster_size(self, device): + + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + """Forward function.""" + # reshape z -> (batch, height, width, channel) + z = rearrange(z, 'b c h w -> b h w c') + z = F.normalize(z, p=2, dim=-1) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + # 'n d -> d n' + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + all_reduce(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # update cluster size with EMA + bins = encodings.sum(0) + all_reduce(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) + + embed_sum = z_flattened.t() @ encodings + all_reduce(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = F.normalize(embed_normalized, p=2, dim=-1) + embed_normalized = torch.where(zero_mask[..., None], + self.embedding.weight, + embed_normalized) + + # Update embedding vectors with EMA + norm_ema_inplace(self.embedding.weight, embed_normalized, + self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, encoding_indices diff --git a/model-index.yml b/model-index.yml index 81897fe9..5856a2b4 100644 --- a/model-index.yml +++ b/model-index.yml @@ -17,3 +17,4 @@ Import: - configs/selfsup/barlowtwins/metafile.yml - configs/selfsup/cae/metafile.yml - configs/selfsup/maskfeat/metafile.yml + - configs/selfsup/beit/metafile.yml diff --git a/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py b/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py index 5d6a40da..505a129e 100644 --- a/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py +++ b/tests/test_engine/test_optimizers/test_layer_decay_optim_wrapper_constructor.py @@ -12,7 +12,7 @@ class ToyViTBackbone(nn.Module): def __init__(self): super().__init__() self.cls_token = nn.Parameter(torch.ones(1)) - self.patch_embed = nn.Parameter(torch.ones(1)) + self.pos_embed = nn.Parameter(torch.ones(1)) self.layers = nn.ModuleList() for _ in range(2): layer = nn.Conv2d(3, 3, 1) @@ -87,32 +87,47 @@ def test_learning_rate_decay_optimizer_wrapper_constructor(): weight_decay=base_wd, model_type='vit', layer_decay_rate=2.0)) + paramwise_cfg = dict( + custom_keys={ + '.bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0), + }) # test when model_type is None with pytest.raises(AssertionError): optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( # noqa - optim_wrapper_cfg=optim_wrapper_cfg) + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg) optim_wrapper_cfg['optimizer']['model_type'] = None optimizer_wrapper = optimizer_wrapper_constructor(model) # test when model_type is invalid with pytest.raises(AssertionError): optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( # noqa - optim_wrapper_cfg=optim_wrapper_cfg) + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg) optim_wrapper_cfg['optimizer']['model_type'] = 'invalid' optimizer_wrapper = optimizer_wrapper_constructor(model) # test vit optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( - optim_wrapper_cfg=optim_wrapper_cfg) + optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg) optim_wrapper_cfg['optimizer']['model_type'] = 'vit' optimizer_wrapper = optimizer_wrapper_constructor(model) check_optimizer_lr_wd(optimizer_wrapper, expected_layer_wise_wd_lr_vit) # test swin + paramwise_cfg = dict( + custom_keys={ + '.norm': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + '.absolute_pos_embed': dict(decay_mult=0.0), + '.relative_position_bias_table': dict(decay_mult=0.0) + }) model = ToySwin() optimizer_wrapper_constructor = LearningRateDecayOptimWrapperConstructor( - optim_wrapper_cfg=optim_wrapper_cfg) + optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg) optim_wrapper_cfg['optimizer']['model_type'] = 'swin' optimizer_wrapper = optimizer_wrapper_constructor(model) assert optimizer_wrapper.optimizer.param_groups[-1]['lr_scale'] == 1.0 diff --git a/tests/test_models/test_algorithms/test_beitv1.py b/tests/test_models/test_algorithms/test_beitv1.py new file mode 100644 index 00000000..05a28338 --- /dev/null +++ b/tests/test_models/test_algorithms/test_beitv1.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch +from mmengine.structures import InstanceData + +from mmselfsup.models import BEiT +from mmselfsup.structures import SelfSupDataSample +from mmselfsup.utils import register_all_modules + +data_preprocessor = dict( + type='TwoNormDataPreprocessor', + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(-20.4, -20.4, -20.4), + second_std=(204., 204., 204.), + bgr_to_rgb=True) + +# model settings +backbone = dict( + type='BEiTViT', + arch='base', + patch_size=16, + drop_path_rate=0.1, + final_norm=True, + layer_scale_init_value=0.1, +) +neck = None +head = dict( + type='BEiTV1Head', + embed_dims=768, + num_embed=8192, + loss=dict(type='BEiTLoss')) +target_generator = dict(type='DALL-E') + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_beitv1(): + register_all_modules() + + model = BEiT( + backbone=backbone, + neck=neck, + head=head, + target_generator=target_generator, + data_preprocessor=data_preprocessor) + + fake_img = torch.rand((1, 3, 224, 224)) + fake_target_img = torch.rand((1, 3, 112, 112)) + fake_mask = torch.zeros((196)).bool() + fake_mask[75:150] = 1 + fake_data_sample = SelfSupDataSample() + fake_mask = InstanceData(value=fake_mask) + fake_data_sample.mask = fake_mask + fake_data_sample = [fake_data_sample] + + fake_data = { + 'inputs': [fake_img, fake_target_img], + 'data_sample': fake_data_sample + } + + fake_batch_inputs, fake_data_samples = model.data_preprocessor(fake_data) + fake_outputs = model(fake_batch_inputs, fake_data_samples, mode='loss') + assert isinstance(fake_outputs['loss'].item(), float) diff --git a/tests/test_models/test_algorithms/test_beitv2.py b/tests/test_models/test_algorithms/test_beitv2.py new file mode 100644 index 00000000..cd13bc24 --- /dev/null +++ b/tests/test_models/test_algorithms/test_beitv2.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch +from mmengine.structures import InstanceData + +from mmselfsup.models import BEiT +from mmselfsup.structures import SelfSupDataSample +from mmselfsup.utils import register_all_modules + +data_preprocessor = dict( + type='TwoNormDataPreprocessor', + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5, 127.5), + second_std=(127.5, 127.5, 127.5), + bgr_to_rgb=True) + +# model settings +vqkd_encoder = dict( + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + with_cls_token=True, + avg_token=False, + frozen_stages=-1, + output_cls_token=False, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0., + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None) + +layer_scale_init_value = 0.1 +drop_path_rate = 0. # 0. for 300 epochs and 0.1 for 1600 epochs. +backbone = dict( + type='BEiTViT', + arch='base', + patch_size=16, + out_indices=[-4, -1], + drop_path_rate=drop_path_rate, + final_norm=False, + layer_scale_init_value=layer_scale_init_value, +) +neck = dict( + type='BEiTV2Neck', + num_layers=1, + early_layers=9, + backbone_arch='base', + drop_path_rate=drop_path_rate, + layer_scale_init_value=layer_scale_init_value, +) +head = dict( + type='BEiTV2Head', + embed_dims=768, + num_embed=8192, + loss=dict(type='BEiTLoss')) +target_generator = dict(type='VQKD', encoder_config=vqkd_encoder) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_beitv2(): + register_all_modules() + + model = BEiT( + backbone=backbone, + neck=neck, + head=head, + target_generator=target_generator, + data_preprocessor=data_preprocessor) + + fake_img = torch.rand((1, 3, 224, 224)) + fake_target_img = torch.rand((1, 3, 224, 224)) + fake_mask = torch.zeros((196)).bool() + fake_mask[75:150] = 1 + fake_data_sample = SelfSupDataSample() + fake_mask = InstanceData(value=fake_mask) + fake_data_sample.mask = fake_mask + fake_data_sample = [fake_data_sample] + + fake_data = { + 'inputs': [fake_img, fake_target_img], + 'data_sample': fake_data_sample + } + + fake_batch_inputs, fake_data_samples = model.data_preprocessor(fake_data) + fake_outputs = model(fake_batch_inputs, fake_data_samples, mode='loss') + assert isinstance(fake_outputs['loss_1'].item(), float) + assert isinstance(fake_outputs['loss_2'].item(), float) diff --git a/tests/test_models/test_backbones/test_beit_vit.py b/tests/test_models/test_backbones/test_beit_vit.py new file mode 100644 index 00000000..e2e877f7 --- /dev/null +++ b/tests/test_models/test_backbones/test_beit_vit.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmselfsup.models.backbones import BEiTViT + +backbone = dict( + arch='base', + patch_size=16, + drop_path_rate=0.1, + final_norm=True, + layer_scale_init_value=0.1, +) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_beit_vit(): + beit_backbone = BEiTViT(**backbone) + beit_backbone.init_weights() + + fake_inputs = torch.randn((2, 3, 224, 224)) + fake_mask = torch.zeros((2, 196)) + fake_mask[:, 75:150] = 1 + fake_outputs = beit_backbone(fake_inputs, fake_mask) + + assert list(fake_outputs[0].shape) == [2, 197, 768] diff --git a/tests/test_models/test_utils/test_dalle.py b/tests/test_models/test_target_generators/test_dalle.py similarity index 86% rename from tests/test_models/test_utils/test_dalle.py rename to tests/test_models/test_target_generators/test_dalle.py index 04aa9b78..911c052a 100644 --- a/tests/test_models/test_utils/test_dalle.py +++ b/tests/test_models/test_target_generators/test_dalle.py @@ -4,7 +4,7 @@ import platform import pytest import torch -from mmselfsup.models.utils import Encoder +from mmselfsup.models.target_generators import Encoder @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') diff --git a/tests/test_models/test_target_generators/test_vqkd.py b/tests/test_models/test_target_generators/test_vqkd.py new file mode 100644 index 00000000..29b64b26 --- /dev/null +++ b/tests/test_models/test_target_generators/test_vqkd.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmselfsup.models.target_generators import VQKD + +vqkd_encoder = dict( + arch='base', + img_size=224, + patch_size=16, + in_channels=3, + out_indices=-1, + drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + final_norm=True, + with_cls_token=True, + avg_token=False, + frozen_stages=-1, + output_cls_token=False, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + use_shared_rel_pos_bias=False, + layer_scale_init_value=0., + interpolate_mode='bicubic', + patch_cfg=dict(), + layer_cfgs=dict(), + init_cfg=None) + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_vqkd(): + model = VQKD(encoder_config=vqkd_encoder) + fake_inputs = torch.rand((2, 3, 224, 224)) + fake_outputs = model(fake_inputs) + + assert list(fake_outputs.shape) == [2, 196] diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py index 82a74f63..bf34a6aa 100644 --- a/tests/test_models/test_utils/test_data_preprocessor.py +++ b/tests/test_models/test_utils/test_data_preprocessor.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest import torch -from mmselfsup.models.utils import SelfSupDataPreprocessor +from mmselfsup.models.utils import (SelfSupDataPreprocessor, + TwoNormDataPreprocessor) from mmselfsup.structures import SelfSupDataSample @@ -16,3 +18,51 @@ def test_selfsup_data_preprocessor(): fake_batches, fake_samples = data_preprocessor(fake_data) assert len(fake_batches) == 1 assert len(fake_samples) == 2 + + +def test_two_norm_data_preprocessor(): + with pytest.raises(AssertionError): + data_preprocessor = TwoNormDataPreprocessor( + rgb_to_bgr=True, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + ) + with pytest.raises(AssertionError): + data_preprocessor = TwoNormDataPreprocessor( + rgb_to_bgr=True, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5), + second_std=(127.5, 127.5, 127.5), + ) + with pytest.raises(AssertionError): + data_preprocessor = TwoNormDataPreprocessor( + rgb_to_bgr=True, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5, 127.5), + second_std=(127.5, 127.5), + ) + + data_preprocessor = dict( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5, 127.5), + second_std=(127.5, 127.5, 127.5), + bgr_to_rgb=True) + + data_preprocessor = TwoNormDataPreprocessor(**data_preprocessor) + fake_data = { + 'inputs': + [torch.randn((4, 3, 224, 224)), + torch.randn((4, 3, 224, 224))], + 'data_sample': [ + SelfSupDataSample(), + SelfSupDataSample(), + SelfSupDataSample(), + SelfSupDataSample() + ] + } + fake_batches, fake_samples = data_preprocessor(fake_data) + assert len(fake_batches) == 2 + assert len(fake_samples) == 4