From fa53174fd9a2bad710e4049fe59ae90b26fe24a9 Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Tue, 8 Aug 2023 16:01:07 +0800 Subject: [PATCH] [Feature]: Add MFF (#1725) * [Feature]: Add MFF * [Feature]: Add mff linear prob * [Feature]: Add ft * [Fix]: Update docstring * [Feature]: Update out_indices * [Feature]: Add prefix to ft * [Feature]: Add README * [Feature]: Update readme * [Feature]: Update README * [Feature]: Add metafile * [Feature]: Update README * [Fix]: Fix lint * [Feature]: Add UT * [Feature]: Update paper link --- README.md | 1 + README_zh-CN.md | 1 + configs/mff/README.md | 60 ++++++ .../vit-base-p16_8xb128-coslr-100e_in1k.py | 114 ++++++++++ ...-base-p16_8xb2048-linear-coslr-90e_in1k.py | 74 +++++++ configs/mff/metafile.yml | 103 ++++++++++ ...vit-base-p16_8xb512-amp-coslr-300e_in1k.py | 24 +++ ...vit-base-p16_8xb512-amp-coslr-800e_in1k.py | 24 +++ mmpretrain/datasets/__init__.py | 2 +- mmpretrain/datasets/transforms/__init__.py | 7 +- mmpretrain/datasets/transforms/processing.py | 53 +++++ .../models/backbones/vision_transformer.py | 2 +- mmpretrain/models/selfsup/__init__.py | 3 + mmpretrain/models/selfsup/mff.py | 194 ++++++++++++++++++ model-index.yml | 1 + tests/test_models/test_selfsup/test_mff.py | 63 ++++++ 16 files changed, 721 insertions(+), 5 deletions(-) create mode 100644 configs/mff/README.md create mode 100644 configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py create mode 100644 configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py create mode 100644 configs/mff/metafile.yml create mode 100644 configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py create mode 100644 configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py create mode 100644 mmpretrain/models/selfsup/mff.py create mode 100644 tests/test_models/test_selfsup/test_mff.py diff --git a/README.md b/README.md index 6456ec28..cb17dfa1 100644 --- a/README.md +++ b/README.md @@ -253,6 +253,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
  • MixMIM (arXiv'2022)
  • iTPN (CVPR'2023)
  • SparK (ICLR'2023)
+  • MFF (ICCV'2023)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 02c62c6f..696215dd 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -249,6 +249,7 @@ mim install -e ".[multimodal]"
  • MixMIM (arXiv'2022)
  • iTPN (CVPR'2023)
  • SparK (ICLR'2023)
+  • MFF (ICCV'2023)
  • diff --git a/configs/mff/README.md b/configs/mff/README.md new file mode 100644 index 00000000..7001c74b --- /dev/null +++ b/configs/mff/README.md @@ -0,0 +1,60 @@ +# MFF + +> [Improving Pixel-based MIM by Reducing Wasted Modeling Capability](https://arxiv.org/abs/2308.00261) + + + +## Abstract + +There has been significant progress in Masked Image Modeling (MIM). Existing MIM methods can be broadly categorized into two groups based on the reconstruction target: pixel-based and tokenizer-based approaches. The former offers a simpler pipeline and lower computational cost, but it is known to be biased toward high-frequency details. In this paper, we provide a set of empirical studies to confirm this limitation of pixel-based MIM and propose a new method that explicitly utilizes low-level features from shallow layers to aid pixel reconstruction. By incorporating this design into our base method, MAE, we reduce the wasted modeling capability of pixel-based MIM, improving its convergence and achieving non-trivial improvements across various downstream tasks. To the best of our knowledge, we are the first to systematically investigate multi-level feature fusion for isotropic architectures like the standard Vision Transformer (ViT). Notably, when applied to a smaller model (e.g., ViT-S), our method yields significant performance gains, such as 1.2% on fine-tuning, 2.8% on linear probing, and 2.6% on semantic segmentation. + +
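In short, MFF keeps the MAE pre-training pipeline intact and only changes what is passed to the decoder: features from several intermediate encoder layers are projected by per-layer linear layers and fused with softmax-normalized learnable weights before pixel reconstruction (see `mmpretrain/models/selfsup/mff.py` in this patch for the full `MFFViT` implementation). The snippet below is a minimal, self-contained sketch of that fusion step; the class name `MultiLevelFusion` and the toy tensors are illustrative, not part of the mmpretrain API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLevelFusion(nn.Module):
    """Minimal sketch of MFF's multi-level feature fusion.

    Features from the selected encoder layers are projected by per-layer
    linear layers (the deepest layer is used as-is) and then combined with
    softmax-normalized learnable weights, mirroring ``MFFViT`` in this patch.
    """

    def __init__(self, embed_dims: int, num_selected_layers: int) -> None:
        super().__init__()
        # One projection per selected layer except the last (deepest) one.
        self.proj_layers = nn.ModuleList([
            nn.Linear(embed_dims, embed_dims)
            for _ in range(num_selected_layers - 1)
        ])
        # One learnable scalar weight per selected layer.
        self.proj_weights = nn.Parameter(
            torch.ones(num_selected_layers).view(-1, 1, 1, 1))

    def forward(self, feats: list) -> torch.Tensor:
        # feats: list of (B, L, C) token tensors, ordered shallow -> deep.
        projected = [proj(x) for proj, x in zip(self.proj_layers, feats[:-1])]
        projected.append(feats[-1])
        stacked = torch.stack(projected)               # (N, B, L, C)
        weights = F.softmax(self.proj_weights, dim=0)  # (N, 1, 1, 1)
        return (stacked * weights).sum(dim=0)          # (B, L, C)


# Toy usage: fuse six levels of a ViT-B-sized encoder (embed dim 768).
fusion = MultiLevelFusion(embed_dims=768, num_selected_layers=6)
feats = [torch.randn(2, 50, 768) for _ in range(6)]
print(fusion(feats).shape)  # torch.Size([2, 50, 768])
```

When `out_indices` selects only a single layer, the fusion weight is frozen and the model falls back to plain MAE.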
    + +
    + +**Train/Test Command** + +Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). + +Train: + +```shell +python tools/train.py configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py +``` + +Test: + +```shell +python tools/test.py configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py None +``` + + + +## Models and results + +### Pretrained models + +| Model | Params (M) | Flops (G) | Config | Download | +| :-------------------------------------------- | :--------: | :-------: | :------------------------------------------------------: | :------------------------------------------------------------------------------: | +| `mff_vit-base-p16_8xb512-amp-coslr-300e_in1k` | - | - | [config](mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.json) | +| `mff_vit-base-p16_8xb512-amp-coslr-800e_in1k` | - | - | [config](mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.json) | + +### Image Classification on ImageNet-1k + +| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Config | Download | +| :---------------------------------------- | :------------------------------------------: | :--------: | :-------: | :-------: | :----------------------------------------: | :-------------------------------------------: | +| `vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k` | [MFF 300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) | 86.57 | 17.58 | 83.00 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth) / [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.json) | +| `vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) | 86.57 | 17.58 | 83.70 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth) / [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.json) | +| `vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 
300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) | 86.57 | 17.58 | 64.20 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k.json) |
+| `vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) | 86.57 | 17.58 | 68.30 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.pth) / [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.json) |
+
+## Citation
+
+```bibtex
+@article{MFF,
+  title={Improving Pixel-based MIM by Reducing Wasted Modeling Capability},
+  author={Yuan Liu and Songyang Zhang and Jiacheng Chen and Zhaohui Yu and Kai Chen and Dahua Lin},
+  journal={arXiv preprint arXiv:2308.00261},
+  year={2023}
+}
+```
diff --git a/configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py b/configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
new file mode 100644
index 00000000..4cf9ca11
--- /dev/null
+++ b/configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
@@ -0,0 +1,114 @@
+_base_ = [
+    '../../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../../_base_/default_runtime.py'
+]
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies='timm_increasing',
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=0.3333333333333333,
+        fill_color=[103.53, 116.28, 123.675],
+        fill_std=[57.375, 57.12, 58.395]),
+    dict(type='PackInputs')
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=256,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='PackInputs')
+]
+
+train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        arch='base',
+        img_size=224,
+        patch_size=16,
+        drop_path_rate=0.1,
+        out_type='avg_featmap',
+        final_norm=False,
+        init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0) + ])) + +# optimizer wrapper +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=2e-3, weight_decay=0.05, betas=(0.9, 0.999)), + constructor='LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + layer_decay_rate=0.65, + custom_keys={ + '.ln': dict(decay_mult=0.0), + '.bias': dict(decay_mult=0.0), + '.cls_token': dict(decay_mult=0.0), + '.pos_embed': dict(decay_mult=0.0) + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=95, + by_epoch=True, + begin=5, + end=100, + eta_min=1e-6, + convert_to_iter_based=True) +] + +# runtime settings +default_hooks = dict( + # save checkpoint per epoch. + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3)) + +train_cfg = dict(by_epoch=True, max_epochs=100) + +randomness = dict(seed=0, diff_rank_seed=True) diff --git a/configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py b/configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py new file mode 100644 index 00000000..dc5f2307 --- /dev/null +++ b/configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py @@ -0,0 +1,74 @@ +_base_ = [ + '../../_base_/datasets/imagenet_bs32_pil_resize.py', + '../../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../../_base_/default_runtime.py' +] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='ToPIL', to_rgb=True), + dict(type='MAERandomResizedCrop', size=224, interpolation=3), + dict(type='torchvision/RandomHorizontalFlip', p=0.5), + dict(type='ToNumpy', to_bgr=True), + dict(type='PackInputs'), +] + +# dataset settings +train_dataloader = dict( + batch_size=2048, drop_last=True, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(drop_last=False) +test_dataloader = dict(drop_last=False) + +# model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='VisionTransformer', + arch='base', + img_size=224, + patch_size=16, + frozen_stages=12, + out_type='cls_token', + final_norm=True, + init_cfg=dict(type='Pretrained', prefix='backbone.')), + neck=dict(type='ClsBatchNormNeck', input_features=768), + head=dict( + type='VisionTransformerClsHead', + num_classes=1000, + in_channels=768, + loss=dict(type='CrossEntropyLoss'), + init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.01)])) + +# optimizer +optim_wrapper = dict( + _delete_=True, + type='AmpOptimWrapper', + optimizer=dict(type='LARS', lr=6.4, weight_decay=0.0, momentum=0.9)) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=10, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=80, + by_epoch=True, + begin=10, + end=90, + eta_min=0.0, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(by_epoch=True, max_epochs=90) + +default_hooks = dict( + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=10)) + +randomness = dict(seed=0, diff_rank_seed=True) diff --git a/configs/mff/metafile.yml b/configs/mff/metafile.yml new file mode 100644 index 00000000..f1da4cc4 --- /dev/null +++ b/configs/mff/metafile.yml @@ -0,0 +1,103 @@ +Collections: + - Name: MFF + Metadata: + Training Data: ImageNet-1k + Training Techniques: + - AdamW + Training Resources: 8x A100-80G GPUs + Architecture: + - ViT + Paper: + Title: Improving Pixel-based 
MIM by Reducing Wasted Modeling Capability
+      URL: https://arxiv.org/pdf/2308.00261.pdf
+    README: configs/mff/README.md
+
+Models:
+  - Name: mff_vit-base-p16_8xb512-amp-coslr-300e_in1k
+    Metadata:
+      Epochs: 300
+      Batch Size: 4096
+      FLOPs: 17581972224
+      Parameters: 85882692
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results: null
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth
+    Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
+    Downstream:
+      - vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
+      - vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
+  - Name: mff_vit-base-p16_8xb512-amp-coslr-800e_in1k
+    Metadata:
+      Epochs: 800
+      Batch Size: 4096
+      FLOPs: 17581972224
+      Parameters: 85882692
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results: null
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth
+    Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py
+    Downstream:
+      - vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
+      - vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
+  - Name: vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
+    Metadata:
+      Epochs: 100
+      Batch Size: 1024
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.0
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth
+    Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
+  - Name: vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
+    Metadata:
+      Epochs: 100
+      Batch Size: 1024
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.7
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth
+    Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
+  - Name: vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
+    Metadata:
+      Epochs: 90
+      Batch Size: 16384
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 64.2
+    Weights:
+    Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
+  - Name: vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
+    Metadata:
+      Epochs: 90
+      Batch Size: 16384
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 68.3
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.pth
+    Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
diff --git a/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py new file mode
100644 index 00000000..f9fc5219 --- /dev/null +++ b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py @@ -0,0 +1,24 @@ +_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py' + +randomness = dict(seed=2, diff_rank_seed=True) + +# dataset config +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='ToPIL', to_rgb=True), + dict(type='torchvision/Resize', size=224), + dict( + type='torchvision/RandomCrop', + size=224, + padding=4, + padding_mode='reflect'), + dict(type='torchvision/RandomHorizontalFlip', p=0.5), + dict(type='ToNumpy', to_bgr=True), + dict(type='PackInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +# model config +model = dict( + type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11])) diff --git a/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py new file mode 100644 index 00000000..d8976b22 --- /dev/null +++ b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py @@ -0,0 +1,24 @@ +_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-800e_in1k.py' + +randomness = dict(seed=2, diff_rank_seed=True) + +# dataset config +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='ToPIL', to_rgb=True), + dict(type='torchvision/Resize', size=224), + dict( + type='torchvision/RandomCrop', + size=224, + padding=4, + padding_mode='reflect'), + dict(type='torchvision/RandomHorizontalFlip', p=0.5), + dict(type='ToNumpy', to_bgr=True), + dict(type='PackInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +# model config +model = dict( + type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11])) diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py index 657ad92f..29753d70 100644 --- a/mmpretrain/datasets/__init__.py +++ b/mmpretrain/datasets/__init__.py @@ -41,8 +41,8 @@ if WITH_MULTIMODAL: from .flickr30k_caption import Flickr30kCaption from .flickr30k_retrieval import Flickr30kRetrieval from .gqa_dataset import GQA - from .infographic_vqa import InfographicVQA from .iconqa import IconQA + from .infographic_vqa import InfographicVQA from .nocaps import NoCaps from .ocr_vqa import OCRVQA from .refcoco import RefCOCO diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py index 8cda99d5..617503f2 100644 --- a/mmpretrain/datasets/transforms/__init__.py +++ b/mmpretrain/datasets/transforms/__init__.py @@ -12,8 +12,9 @@ from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs, PILToNumpy, Transpose) from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption, ColorJitter, EfficientNetCenterCrop, - EfficientNetRandomCrop, Lighting, RandomCrop, - RandomErasing, RandomResizedCrop, + EfficientNetRandomCrop, Lighting, + MAERandomResizedCrop, RandomCrop, RandomErasing, + RandomResizedCrop, RandomResizedCropAndInterpolationWithTwoPic, RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator) from .utils import get_transform_idx, remove_transform @@ -36,5 +37,5 @@ __all__ = [ 'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView', 'ApplyToList', 'CleanCaption', 'RandomTranslatePad', 'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx', - 'remove_transform' + 'remove_transform', 'MAERandomResizedCrop' ] diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py index dfd3ed6b..4c640f6b 100644 --- 
a/mmpretrain/datasets/transforms/processing.py +++ b/mmpretrain/datasets/transforms/processing.py @@ -11,9 +11,13 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union import mmcv import mmengine import numpy as np +import torch import torchvision +import torchvision.transforms.functional as F from mmcv.transforms import BaseTransform from mmcv.transforms.utils import cache_randomness +from PIL import Image +from torchvision import transforms from torchvision.transforms.transforms import InterpolationMode from mmpretrain.registry import TRANSFORMS @@ -1740,3 +1744,52 @@ class RandomTranslatePad(BaseTransform): results['gt_bboxes'][i] = box return results + + +@TRANSFORMS.register_module() +class MAERandomResizedCrop(transforms.RandomResizedCrop): + """RandomResizedCrop for matching TF/TPU implementation: no for-loop is + used. + + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 # noqa: E501 + """ + + @staticmethod + def get_params(img: Image.Image, scale: tuple, ratio: tuple) -> Tuple: + width, height = img.size + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1, )).item() + j = torch.randint(0, width - w + 1, size=(1, )).item() + + return i, j, h, w + + def forward(self, results: dict) -> dict: + """The forward function of MAERandomResizedCrop. + + Args: + results (dict): The results dict contains the image and all these + information related to the image. + + Returns: + dict: The results dict contains the cropped image and all these + information related to the image. 
+ """ + img = results['img'] + i, j, h, w = self.get_params(img, self.scale, self.ratio) + img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation) + results['img'] = img + return results diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py index 0e9efa34..a54053c2 100644 --- a/mmpretrain/models/backbones/vision_transformer.py +++ b/mmpretrain/models/backbones/vision_transformer.py @@ -436,7 +436,7 @@ class VisionTransformer(BaseBackbone): for param in self.pre_norm.parameters(): param.requires_grad = False # freeze cls_token - if self.cls_token: + if self.cls_token is not None: self.cls_token.requires_grad = False # freeze layers for i in range(1, self.frozen_stages + 1): diff --git a/mmpretrain/models/selfsup/__init__.py b/mmpretrain/models/selfsup/__init__.py index 1052dedc..08c1ed59 100644 --- a/mmpretrain/models/selfsup/__init__.py +++ b/mmpretrain/models/selfsup/__init__.py @@ -9,6 +9,7 @@ from .eva import EVA from .itpn import iTPN, iTPNHiViT from .mae import MAE, MAEHiViT, MAEViT from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT +from .mff import MFF, MFFViT from .milan import MILAN, CLIPGenerator, MILANViT from .mixmim import MixMIM, MixMIMPretrainTransformer from .moco import MoCo @@ -53,4 +54,6 @@ __all__ = [ 'BarlowTwins', 'SwAV', 'SparK', + 'MFF', + 'MFFViT', ] diff --git a/mmpretrain/models/selfsup/mff.py b/mmpretrain/models/selfsup/mff.py new file mode 100644 index 00000000..26850580 --- /dev/null +++ b/mmpretrain/models/selfsup/mff.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F + +from mmpretrain.models.selfsup.mae import MAE, MAEViT +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample + + +@MODELS.register_module() +class MFFViT(MAEViT): + """Vision Transformer for MFF Pretraining. + + This class inherits all these functionalities from ``MAEViT``, and + add multi-level feature fusion to it. For more details, you can + refer to `Improving Pixel-based MIM by Reducing Wasted Modeling + Capability`. + + Args: + arch (str | dict): Vision Transformer architecture + Default: 'b' + img_size (int | tuple): Input image size + patch_size (int | tuple): The patch size + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + final_norm (bool): Whether to add a additional layer to normalize + final feature map. Defaults to True. + out_type (str): The type of output features. Please choose from + + - ``"cls_token"``: The class token tensor with shape (B, C). + - ``"featmap"``: The feature map tensor from the patch tokens + with shape (B, C, H, W). + - ``"avg_featmap"``: The global averaged feature map tensor + with shape (B, C). + - ``"raw"``: The raw feature tensor includes patch tokens and + class tokens with shape (B, L, C). + + It only works without input mask. Defaults to ``"avg_featmap"``. + interpolate_mode (str): Select the interpolate mode for position + embeding vector resize. Defaults to "bicubic". + patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict. + layer_cfgs (Sequence | dict): Configs of each transformer layer in + encoder. 
Defaults to an empty dict. + mask_ratio (bool): The ratio of total number of patches to be masked. + Defaults to 0.75. + init_cfg (Union[List[dict], dict], optional): Initialization config + dict. Defaults to None. + """ + + def __init__(self, + arch: Union[str, dict] = 'b', + img_size: int = 224, + patch_size: int = 16, + out_indices: Union[Sequence, int] = -1, + drop_rate: float = 0, + drop_path_rate: float = 0, + norm_cfg: dict = dict(type='LN', eps=1e-6), + final_norm: bool = True, + out_type: str = 'raw', + interpolate_mode: str = 'bicubic', + patch_cfg: dict = dict(), + layer_cfgs: dict = dict(), + mask_ratio: float = 0.75, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__( + arch=arch, + img_size=img_size, + patch_size=patch_size, + out_indices=out_indices, + drop_rate=drop_rate, + drop_path_rate=drop_path_rate, + norm_cfg=norm_cfg, + final_norm=final_norm, + out_type=out_type, + interpolate_mode=interpolate_mode, + patch_cfg=patch_cfg, + layer_cfgs=layer_cfgs, + mask_ratio=mask_ratio, + init_cfg=init_cfg) + proj_layers = [ + torch.nn.Linear(self.embed_dims, self.embed_dims) + for _ in range(len(self.out_indices) - 1) + ] + self.proj_layers = torch.nn.ModuleList(proj_layers) + self.proj_weights = torch.nn.Parameter( + torch.ones(len(self.out_indices)).view(-1, 1, 1, 1)) + if len(self.out_indices) == 1: + self.proj_weights.requires_grad = False + + def forward( + self, + x: torch.Tensor, + mask: Optional[bool] = True + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate features for masked images. + + The function supports two kind of forward behaviors. If the ``mask`` is + ``True``, the function will generate mask to masking some patches + randomly and get the hidden features for visible patches, which means + the function will be executed as masked imagemodeling pre-training; + if the ``mask`` is ``None`` or ``False``, the forward function will + call ``super().forward()``, which extract features from images without + mask. + + + Args: + x (torch.Tensor): Input images, which is of shape B x C x H x W. + mask (bool, optional): To indicate whether the forward function + generating ``mask`` or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features, + mask and the ids to restore original image. + + - ``x`` (torch.Tensor): hidden features, which is of shape + B x (L * mask_ratio) x C. + - ``mask`` (torch.Tensor): mask used to mask image. + - ``ids_restore`` (torch.Tensor): ids to restore original image. + """ + if mask is None or False: + return super().forward(x) + + else: + B = x.shape[0] + x = self.patch_embed(x)[0] + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, self.mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + res = [] + for i, layer in enumerate(self.layers): + x = layer(x) + if i in self.out_indices: + if i != self.out_indices[-1]: + proj_x = self.proj_layers[self.out_indices.index(i)](x) + else: + proj_x = x + res.append(proj_x) + res = torch.stack(res) + proj_weights = F.softmax(self.proj_weights, dim=0) + res = res * proj_weights + res = res.sum(dim=0) + + # Use final norm + x = self.norm1(res) + return (x, mask, ids_restore, proj_weights.view(-1)) + + +@MODELS.register_module() +class MFF(MAE): + """MFF. 
+ + Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling + Capability`. + """ + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (torch.Tensor): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + # ids_restore: the same as that in original repo, which is used + # to recover the original order of tokens in decoder. + latent, mask, ids_restore, weights = self.backbone(inputs) + pred = self.neck(latent, ids_restore) + loss = self.head.loss(pred, inputs, mask) + weight_params = { + f'weight_{i}': weights[i] + for i in range(weights.size(0)) + } + losses = dict(loss=loss) + losses.update(weight_params) + return losses diff --git a/model-index.yml b/model-index.yml index 3fb3d045..1bd92853 100644 --- a/model-index.yml +++ b/model-index.yml @@ -82,3 +82,4 @@ Import: - configs/minigpt4/metafile.yml - configs/llava/metafile.yml - configs/otter/metafile.yml + - configs/mff/metafile.yml diff --git a/tests/test_models/test_selfsup/test_mff.py b/tests/test_models/test_selfsup/test_mff.py new file mode 100644 index 00000000..3ad0295f --- /dev/null +++ b/tests/test_models/test_selfsup/test_mff.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmpretrain.models import MFF, MFFViT +from mmpretrain.structures import DataSample + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_mae_vit(): + backbone = dict( + arch='b', patch_size=16, mask_ratio=0.75, out_indices=[1, 11]) + mae_backbone = MFFViT(**backbone) + mae_backbone.init_weights() + fake_inputs = torch.randn((2, 3, 224, 224)) + + # test with mask + fake_outputs = mae_backbone(fake_inputs)[0] + assert list(fake_outputs.shape) == [2, 50, 768] + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_mae(): + data_preprocessor = { + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'to_rgb': True + } + backbone = dict( + type='MFFViT', + arch='b', + patch_size=16, + mask_ratio=0.75, + out_indices=[1, 11]) + neck = dict( + type='MAEPretrainDecoder', + patch_size=16, + in_chans=3, + embed_dim=768, + decoder_embed_dim=512, + decoder_depth=8, + decoder_num_heads=16, + mlp_ratio=4., + ) + loss = dict(type='PixelReconstructionLoss', criterion='L2') + head = dict( + type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss) + + alg = MFF( + backbone=backbone, + neck=neck, + head=head, + data_preprocessor=data_preprocessor) + + fake_data = { + 'inputs': torch.randn((2, 3, 224, 224)), + 'data_samples': [DataSample() for _ in range(2)] + } + fake_inputs = alg.data_preprocessor(fake_data) + fake_outputs = alg(**fake_inputs, mode='loss') + assert isinstance(fake_outputs['loss'].item(), float)
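As a usage note, `MFF.loss` returns the softmax-normalized fusion weight of every selected layer alongside the reconstruction loss, under the keys `weight_0`, `weight_1`, ..., so the contribution of each level can be monitored during pre-training. Below is a minimal sketch that reuses the same toy configuration as the unit test above; it is an illustration of the returned loss dict, not a training recipe.

```python
import torch

from mmpretrain.models import MFF
from mmpretrain.structures import DataSample

# Same toy configuration as in the unit test above.
algorithm = MFF(
    backbone=dict(
        type='MFFViT', arch='b', patch_size=16, mask_ratio=0.75,
        out_indices=[1, 11]),
    neck=dict(
        type='MAEPretrainDecoder', patch_size=16, in_chans=3, embed_dim=768,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4.),
    head=dict(
        type='MAEPretrainHead', norm_pix=False, patch_size=16,
        loss=dict(type='PixelReconstructionLoss', criterion='L2')),
    data_preprocessor=dict(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], to_rgb=True))

data = algorithm.data_preprocessor({
    'inputs': torch.randn((2, 3, 224, 224)),
    'data_samples': [DataSample() for _ in range(2)],
})
losses = algorithm(**data, mode='loss')

# One `weight_i` entry per index in `out_indices`; the weights sum to 1.
print(losses['loss'].item(), losses['weight_0'].item(),
      losses['weight_1'].item())
```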