diff --git a/README.md b/README.md
index 6456ec28..cb17dfa1 100644
--- a/README.md
+++ b/README.md
@@ -253,6 +253,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
         <li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
         <li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
         <li><a href="configs/spark">SparK (ICLR'2023)</a></li>
+        <li><a href="configs/mff">MFF (ICCV'2023)</a></li>
         </ul>
       </td>
       <td>
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 02c62c6f..696215dd 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -249,6 +249,7 @@ mim install -e ".[multimodal]"
         <li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
         <li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
         <li><a href="configs/spark">SparK (ICLR'2023)</a></li>
+        <li><a href="configs/mff">MFF (ICCV'2023)</a></li>
         </ul>
       </td>
       <td>
diff --git a/configs/mff/README.md b/configs/mff/README.md
new file mode 100644
index 00000000..7001c74b
--- /dev/null
+++ b/configs/mff/README.md
@@ -0,0 +1,60 @@
+# MFF
+
+> [Improving Pixel-based MIM by Reducing Wasted Modeling Capability](https://arxiv.org/abs/2308.00261)
+
+<!-- [ALGORITHM] -->
+
+## Abstract
+
+There has been significant progress in Masked Image Modeling (MIM). Existing MIM methods can be broadly categorized into two groups based on the reconstruction target: pixel-based and tokenizer-based approaches. The former offers a simpler pipeline and lower computational cost, but it is known to be biased toward high-frequency details. In this paper, we provide a set of empirical studies to confirm this limitation of pixel-based MIM and propose a new method that explicitly utilizes low-level features from shallow layers to aid pixel reconstruction. By incorporating this design into our base method, MAE, we reduce the wasted modeling capability of pixel-based MIM, improving its convergence and achieving non-trivial improvements across various downstream tasks. To the best of our knowledge, we are the first to systematically investigate multi-level feature fusion for isotropic architectures like the standard Vision Transformer (ViT). Notably, when applied to a smaller model (e.g., ViT-S), our method yields significant performance gains, such as 1.2% on fine-tuning, 2.8% on linear probing, and 2.6% on semantic segmentation.
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/30762564/257412932-5f36b11b-ee64-4ce7-b7d1-a31000302bd8.png" width="80%"/>
+</div>
+
+**Train/Test Command**
+
+Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
+
+Train:
+
+```shell
+python tools/train.py configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
+```
+
+Test:
+
+```shell
+python tools/test.py configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py None
+```
+
+<!-- [TABS-END] -->
+
+## Models and results
+
+### Pretrained models
+
+| Model                                         | Params (M) | Flops (G) |                          Config                          |                                     Download                                     |
+| :-------------------------------------------- | :--------: | :-------: | :------------------------------------------------------: | :------------------------------------------------------------------------------: |
+| `mff_vit-base-p16_8xb512-amp-coslr-300e_in1k` |     -      |     -     | [config](mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.json) |
+| `mff_vit-base-p16_8xb512-amp-coslr-800e_in1k` |     -      |     -     | [config](mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.json) |
+
+### Image Classification on ImageNet-1k
+
+| Model                                     |                   Pretrain                   | Params (M) | Flops (G) | Top-1 (%) |                   Config                   |                   Download                    |
+| :---------------------------------------- | :------------------------------------------: | :--------: | :-------: | :-------: | :----------------------------------------: | :-------------------------------------------: |
+| `vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k` | [MFF 300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) |   86.57    |   17.58   |   83.00   | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth)  /   [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.json) |
+| `vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) |   86.57    |   17.58   |   83.70   | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth) / [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.json) |
+| `vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) |   304.33   |   61.60   |   64.20   | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k.json) |
+| `vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k_20220825-f7569ca2.pth) |   304.33   |   61.60   |   68.30   | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.pth)  / [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.json) |
+
+## Citation
+
+```bibtex
+@article{MFF,
+  title={Improving Pixel-based MIM by Reducing Wasted Modeling Capability},
+  author={Yuan Liu, Songyang Zhang, Jiacheng Chen, Zhaohui Yu, Kai Chen, Dahua Lin},
+  journal={arXiv},
+  year={2023}
+}
+```
diff --git a/configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py b/configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
new file mode 100644
index 00000000..4cf9ca11
--- /dev/null
+++ b/configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
@@ -0,0 +1,114 @@
+_base_ = [
+    '../../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../../_base_/default_runtime.py'
+]
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies='timm_increasing',
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=0.3333333333333333,
+        fill_color=[103.53, 116.28, 123.675],
+        fill_std=[57.375, 57.12, 58.395]),
+    dict(type='PackInputs')
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=256,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='PackInputs')
+]
+
+train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        arch='base',
+        img_size=224,
+        patch_size=16,
+        drop_path_rate=0.1,
+        out_type='avg_featmap',
+        final_norm=False,
+        init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
+    ]))
+
+# optimizer wrapper
+optim_wrapper = dict(
+    optimizer=dict(
+        type='AdamW', lr=2e-3, weight_decay=0.05, betas=(0.9, 0.999)),
+    constructor='LearningRateDecayOptimWrapperConstructor',
+    paramwise_cfg=dict(
+        layer_decay_rate=0.65,
+        custom_keys={
+            '.ln': dict(decay_mult=0.0),
+            '.bias': dict(decay_mult=0.0),
+            '.cls_token': dict(decay_mult=0.0),
+            '.pos_embed': dict(decay_mult=0.0)
+        }))
+
+# learning rate scheduler
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=5,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR',
+        T_max=95,
+        by_epoch=True,
+        begin=5,
+        end=100,
+        eta_min=1e-6,
+        convert_to_iter_based=True)
+]
+
+# runtime settings
+default_hooks = dict(
+    # save checkpoint per epoch.
+    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
+
+train_cfg = dict(by_epoch=True, max_epochs=100)
+
+randomness = dict(seed=0, diff_rank_seed=True)
diff --git a/configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py b/configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
new file mode 100644
index 00000000..dc5f2307
--- /dev/null
+++ b/configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
@@ -0,0 +1,74 @@
+_base_ = [
+    '../../_base_/datasets/imagenet_bs32_pil_resize.py',
+    '../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../../_base_/default_runtime.py'
+]
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='ToPIL', to_rgb=True),
+    dict(type='MAERandomResizedCrop', size=224, interpolation=3),
+    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+    dict(type='ToNumpy', to_bgr=True),
+    dict(type='PackInputs'),
+]
+
+# dataset settings
+train_dataloader = dict(
+    batch_size=2048, drop_last=True, dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(drop_last=False)
+test_dataloader = dict(drop_last=False)
+
+# model settings
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='VisionTransformer',
+        arch='base',
+        img_size=224,
+        patch_size=16,
+        frozen_stages=12,
+        out_type='cls_token',
+        final_norm=True,
+        init_cfg=dict(type='Pretrained', prefix='backbone.')),
+    neck=dict(type='ClsBatchNormNeck', input_features=768),
+    head=dict(
+        type='VisionTransformerClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(type='CrossEntropyLoss'),
+        init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.01)]))
+
+# optimizer
+optim_wrapper = dict(
+    _delete_=True,
+    type='AmpOptimWrapper',
+    optimizer=dict(type='LARS', lr=6.4, weight_decay=0.0, momentum=0.9))
+
+# learning rate scheduler
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1e-4,
+        by_epoch=True,
+        begin=0,
+        end=10,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR',
+        T_max=80,
+        by_epoch=True,
+        begin=10,
+        end=90,
+        eta_min=0.0,
+        convert_to_iter_based=True)
+]
+
+# runtime settings
+train_cfg = dict(by_epoch=True, max_epochs=90)
+
+default_hooks = dict(
+    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
+    logger=dict(type='LoggerHook', interval=10))
+
+randomness = dict(seed=0, diff_rank_seed=True)
diff --git a/configs/mff/metafile.yml b/configs/mff/metafile.yml
new file mode 100644
index 00000000..f1da4cc4
--- /dev/null
+++ b/configs/mff/metafile.yml
@@ -0,0 +1,103 @@
+Collections:
+  - Name: MFF
+    Metadata:
+      Training Data: ImageNet-1k
+      Training Techniques:
+        - AdamW
+      Training Resources: 8x A100-80G GPUs
+      Architecture:
+        - ViT
+    Paper:
+      Title: Improving Pixel-based MIM by Reducing Wasted Modeling Capability
+      URL: https://arxiv.org/pdf/2308.00261.pdf
+    README: configs/mff/README.md
+
+Models:
+  - Name: mff_vit-base-p16_8xb512-amp-coslr-300e_in1k
+    Metadata:
+      Epochs: 300
+      Batch Size: 2048
+      FLOPs: 17581972224
+      Parameters: 85882692
+      Training Data: ImageNet-1k
+    In Collection: MaskFeat
+    Results: null
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth
+    Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
+    Downstream:
+      - vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
+      - vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
+  - Name: mff_vit-base-p16_8xb512-amp-coslr-800e_in1k
+    Metadata:
+      Epochs: 800
+      Batch Size: 2048
+      FLOPs: 17581972224
+      Parameters: 85882692
+      Training Data: ImageNet-1k
+    In Collection: MaskFeat
+    Results: null
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth
+    Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py
+    Downstream:
+      - vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
+      - vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
+  - Name: vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
+    Metadata:
+      Epochs: 100
+      Batch Size: 1024
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MaskFeat
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.0
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth
+    Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
+  - Name: vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
+    Metadata:
+      Epochs: 100
+      Batch Size: 1024
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.7
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth
+    Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
+  - Name: vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
+    Metadata:
+      Epochs: 90
+      Batch Size: 16384
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 64.2
+    Weights:
+    Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
+  - Name: vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
+    Metadata:
+      Epochs: 90
+      Batch Size: 16384
+      FLOPs: 17581215744
+      Parameters: 86566120
+      Training Data: ImageNet-1k
+    In Collection: MFF
+    Results:
+      - Task: Image Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 68.3
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth
+    Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
diff --git a/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
new file mode 100644
index 00000000..f9fc5219
--- /dev/null
+++ b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
@@ -0,0 +1,24 @@
+_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'
+
+randomness = dict(seed=2, diff_rank_seed=True)
+
+# dataset config
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='ToPIL', to_rgb=True),
+    dict(type='torchvision/Resize', size=224),
+    dict(
+        type='torchvision/RandomCrop',
+        size=224,
+        padding=4,
+        padding_mode='reflect'),
+    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+    dict(type='ToNumpy', to_bgr=True),
+    dict(type='PackInputs')
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+
+# model config
+model = dict(
+    type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11]))
diff --git a/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py
new file mode 100644
index 00000000..d8976b22
--- /dev/null
+++ b/configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py
@@ -0,0 +1,24 @@
+_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-800e_in1k.py'
+
+randomness = dict(seed=2, diff_rank_seed=True)
+
+# dataset config
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='ToPIL', to_rgb=True),
+    dict(type='torchvision/Resize', size=224),
+    dict(
+        type='torchvision/RandomCrop',
+        size=224,
+        padding=4,
+        padding_mode='reflect'),
+    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
+    dict(type='ToNumpy', to_bgr=True),
+    dict(type='PackInputs')
+]
+
+train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
+
+# model config
+model = dict(
+    type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11]))
diff --git a/mmpretrain/datasets/__init__.py b/mmpretrain/datasets/__init__.py
index 657ad92f..29753d70 100644
--- a/mmpretrain/datasets/__init__.py
+++ b/mmpretrain/datasets/__init__.py
@@ -41,8 +41,8 @@ if WITH_MULTIMODAL:
     from .flickr30k_caption import Flickr30kCaption
     from .flickr30k_retrieval import Flickr30kRetrieval
     from .gqa_dataset import GQA
-    from .infographic_vqa import InfographicVQA
     from .iconqa import IconQA
+    from .infographic_vqa import InfographicVQA
     from .nocaps import NoCaps
     from .ocr_vqa import OCRVQA
     from .refcoco import RefCOCO
diff --git a/mmpretrain/datasets/transforms/__init__.py b/mmpretrain/datasets/transforms/__init__.py
index 8cda99d5..617503f2 100644
--- a/mmpretrain/datasets/transforms/__init__.py
+++ b/mmpretrain/datasets/transforms/__init__.py
@@ -12,8 +12,9 @@ from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
                          PILToNumpy, Transpose)
 from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
                          ColorJitter, EfficientNetCenterCrop,
-                         EfficientNetRandomCrop, Lighting, RandomCrop,
-                         RandomErasing, RandomResizedCrop,
+                         EfficientNetRandomCrop, Lighting,
+                         MAERandomResizedCrop, RandomCrop, RandomErasing,
+                         RandomResizedCrop,
                          RandomResizedCropAndInterpolationWithTwoPic,
                          RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
 from .utils import get_transform_idx, remove_transform
@@ -36,5 +37,5 @@ __all__ = [
     'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
     'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
     'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
-    'remove_transform'
+    'remove_transform', 'MAERandomResizedCrop'
 ]
diff --git a/mmpretrain/datasets/transforms/processing.py b/mmpretrain/datasets/transforms/processing.py
index dfd3ed6b..4c640f6b 100644
--- a/mmpretrain/datasets/transforms/processing.py
+++ b/mmpretrain/datasets/transforms/processing.py
@@ -11,9 +11,13 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 import mmcv
 import mmengine
 import numpy as np
+import torch
 import torchvision
+import torchvision.transforms.functional as F
 from mmcv.transforms import BaseTransform
 from mmcv.transforms.utils import cache_randomness
+from PIL import Image
+from torchvision import transforms
 from torchvision.transforms.transforms import InterpolationMode
 
 from mmpretrain.registry import TRANSFORMS
@@ -1740,3 +1744,52 @@ class RandomTranslatePad(BaseTransform):
                 results['gt_bboxes'][i] = box
 
         return results
+
+
+@TRANSFORMS.register_module()
+class MAERandomResizedCrop(transforms.RandomResizedCrop):
+    """RandomResizedCrop for matching TF/TPU implementation: no for-loop is
+    used.
+
+    This may lead to results different with torchvision's version.
+    Following BYOL's TF code:
+    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 # noqa: E501
+    """
+
+    @staticmethod
+    def get_params(img: Image.Image, scale: tuple, ratio: tuple) -> Tuple:
+        width, height = img.size
+        area = height * width
+
+        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
+        log_ratio = torch.log(torch.tensor(ratio))
+        aspect_ratio = torch.exp(
+            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
+
+        w = int(round(math.sqrt(target_area * aspect_ratio)))
+        h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+        w = min(w, width)
+        h = min(h, height)
+
+        i = torch.randint(0, height - h + 1, size=(1, )).item()
+        j = torch.randint(0, width - w + 1, size=(1, )).item()
+
+        return i, j, h, w
+
+    def forward(self, results: dict) -> dict:
+        """The forward function of MAERandomResizedCrop.
+
+        Args:
+            results (dict): The results dict contains the image and all these
+                information related to the image.
+
+        Returns:
+            dict: The results dict contains the cropped image and all these
+            information related to the image.
+        """
+        img = results['img']
+        i, j, h, w = self.get_params(img, self.scale, self.ratio)
+        img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
+        results['img'] = img
+        return results
diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py
index 0e9efa34..a54053c2 100644
--- a/mmpretrain/models/backbones/vision_transformer.py
+++ b/mmpretrain/models/backbones/vision_transformer.py
@@ -436,7 +436,7 @@ class VisionTransformer(BaseBackbone):
         for param in self.pre_norm.parameters():
             param.requires_grad = False
         # freeze cls_token
-        if self.cls_token:
+        if self.cls_token is not None:
             self.cls_token.requires_grad = False
         # freeze layers
         for i in range(1, self.frozen_stages + 1):
diff --git a/mmpretrain/models/selfsup/__init__.py b/mmpretrain/models/selfsup/__init__.py
index 1052dedc..08c1ed59 100644
--- a/mmpretrain/models/selfsup/__init__.py
+++ b/mmpretrain/models/selfsup/__init__.py
@@ -9,6 +9,7 @@ from .eva import EVA
 from .itpn import iTPN, iTPNHiViT
 from .mae import MAE, MAEHiViT, MAEViT
 from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT
+from .mff import MFF, MFFViT
 from .milan import MILAN, CLIPGenerator, MILANViT
 from .mixmim import MixMIM, MixMIMPretrainTransformer
 from .moco import MoCo
@@ -53,4 +54,6 @@ __all__ = [
     'BarlowTwins',
     'SwAV',
     'SparK',
+    'MFF',
+    'MFFViT',
 ]
diff --git a/mmpretrain/models/selfsup/mff.py b/mmpretrain/models/selfsup/mff.py
new file mode 100644
index 00000000..26850580
--- /dev/null
+++ b/mmpretrain/models/selfsup/mff.py
@@ -0,0 +1,194 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+
+from mmpretrain.models.selfsup.mae import MAE, MAEViT
+from mmpretrain.registry import MODELS
+from mmpretrain.structures import DataSample
+
+
+@MODELS.register_module()
+class MFFViT(MAEViT):
+    """Vision Transformer for MFF Pretraining.
+
+    This class inherits all these functionalities from ``MAEViT``, and
+    add multi-level feature fusion to it. For more details, you can
+    refer to `Improving Pixel-based MIM by Reducing Wasted Modeling
+    Capability`.
+
+    Args:
+        arch (str | dict): Vision Transformer architecture
+            Default: 'b'
+        img_size (int | tuple): Input image size
+        patch_size (int | tuple): The patch size
+        out_indices (Sequence | int): Output from which stages.
+            Defaults to -1, means the last stage.
+        drop_rate (float): Probability of an element to be zeroed.
+            Defaults to 0.
+        drop_path_rate (float): stochastic depth rate. Defaults to 0.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        final_norm (bool): Whether to add a additional layer to normalize
+            final feature map. Defaults to True.
+        out_type (str): The type of output features. Please choose from
+
+            - ``"cls_token"``: The class token tensor with shape (B, C).
+            - ``"featmap"``: The feature map tensor from the patch tokens
+              with shape (B, C, H, W).
+            - ``"avg_featmap"``: The global averaged feature map tensor
+              with shape (B, C).
+            - ``"raw"``: The raw feature tensor includes patch tokens and
+              class tokens with shape (B, L, C).
+
+            It only works without input mask. Defaults to ``"avg_featmap"``.
+        interpolate_mode (str): Select the interpolate mode for position
+            embeding vector resize. Defaults to "bicubic".
+        patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        mask_ratio (bool): The ratio of total number of patches to be masked.
+            Defaults to 0.75.
+        init_cfg (Union[List[dict], dict], optional): Initialization config
+            dict. Defaults to None.
+    """
+
+    def __init__(self,
+                 arch: Union[str, dict] = 'b',
+                 img_size: int = 224,
+                 patch_size: int = 16,
+                 out_indices: Union[Sequence, int] = -1,
+                 drop_rate: float = 0,
+                 drop_path_rate: float = 0,
+                 norm_cfg: dict = dict(type='LN', eps=1e-6),
+                 final_norm: bool = True,
+                 out_type: str = 'raw',
+                 interpolate_mode: str = 'bicubic',
+                 patch_cfg: dict = dict(),
+                 layer_cfgs: dict = dict(),
+                 mask_ratio: float = 0.75,
+                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+        super().__init__(
+            arch=arch,
+            img_size=img_size,
+            patch_size=patch_size,
+            out_indices=out_indices,
+            drop_rate=drop_rate,
+            drop_path_rate=drop_path_rate,
+            norm_cfg=norm_cfg,
+            final_norm=final_norm,
+            out_type=out_type,
+            interpolate_mode=interpolate_mode,
+            patch_cfg=patch_cfg,
+            layer_cfgs=layer_cfgs,
+            mask_ratio=mask_ratio,
+            init_cfg=init_cfg)
+        proj_layers = [
+            torch.nn.Linear(self.embed_dims, self.embed_dims)
+            for _ in range(len(self.out_indices) - 1)
+        ]
+        self.proj_layers = torch.nn.ModuleList(proj_layers)
+        self.proj_weights = torch.nn.Parameter(
+            torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
+        if len(self.out_indices) == 1:
+            self.proj_weights.requires_grad = False
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[bool] = True
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Generate features for masked images.
+
+        The function supports two kind of forward behaviors. If the ``mask`` is
+        ``True``, the function will generate mask to masking some patches
+        randomly and get the hidden features for visible patches, which means
+        the function will be executed as masked imagemodeling pre-training;
+        if the ``mask`` is ``None`` or ``False``, the forward function will
+        call ``super().forward()``, which extract features from images without
+        mask.
+
+
+        Args:
+            x (torch.Tensor): Input images, which is of shape B x C x H x W.
+            mask (bool, optional): To indicate whether the forward function
+                generating ``mask`` or not.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
+            mask and the ids to restore original image.
+
+            - ``x`` (torch.Tensor): hidden features, which is of shape
+              B x (L * mask_ratio) x C.
+            - ``mask`` (torch.Tensor): mask used to mask image.
+            - ``ids_restore`` (torch.Tensor): ids to restore original image.
+        """
+        if mask is None or False:
+            return super().forward(x)
+
+        else:
+            B = x.shape[0]
+            x = self.patch_embed(x)[0]
+            # add pos embed w/o cls token
+            x = x + self.pos_embed[:, 1:, :]
+
+            # masking: length -> length * mask_ratio
+            x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
+
+            # append cls token
+            cls_token = self.cls_token + self.pos_embed[:, :1, :]
+            cls_tokens = cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, x), dim=1)
+
+            res = []
+            for i, layer in enumerate(self.layers):
+                x = layer(x)
+                if i in self.out_indices:
+                    if i != self.out_indices[-1]:
+                        proj_x = self.proj_layers[self.out_indices.index(i)](x)
+                    else:
+                        proj_x = x
+                    res.append(proj_x)
+            res = torch.stack(res)
+            proj_weights = F.softmax(self.proj_weights, dim=0)
+            res = res * proj_weights
+            res = res.sum(dim=0)
+
+            # Use final norm
+            x = self.norm1(res)
+            return (x, mask, ids_restore, proj_weights.view(-1))
+
+
+@MODELS.register_module()
+class MFF(MAE):
+    """MFF.
+
+    Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
+    Capability`.
+    """
+
+    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
+             **kwargs) -> Dict[str, torch.Tensor]:
+        """The forward function in training.
+
+        Args:
+            inputs (torch.Tensor): The input images.
+            data_samples (List[DataSample]): All elements required
+                during the forward function.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dictionary of loss components.
+        """
+        # ids_restore: the same as that in original repo, which is used
+        # to recover the original order of tokens in decoder.
+        latent, mask, ids_restore, weights = self.backbone(inputs)
+        pred = self.neck(latent, ids_restore)
+        loss = self.head.loss(pred, inputs, mask)
+        weight_params = {
+            f'weight_{i}': weights[i]
+            for i in range(weights.size(0))
+        }
+        losses = dict(loss=loss)
+        losses.update(weight_params)
+        return losses
diff --git a/model-index.yml b/model-index.yml
index 3fb3d045..1bd92853 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -82,3 +82,4 @@ Import:
   - configs/minigpt4/metafile.yml
   - configs/llava/metafile.yml
   - configs/otter/metafile.yml
+  - configs/mff/metafile.yml
diff --git a/tests/test_models/test_selfsup/test_mff.py b/tests/test_models/test_selfsup/test_mff.py
new file mode 100644
index 00000000..3ad0295f
--- /dev/null
+++ b/tests/test_models/test_selfsup/test_mff.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+import torch
+
+from mmpretrain.models import MFF, MFFViT
+from mmpretrain.structures import DataSample
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_mae_vit():
+    backbone = dict(
+        arch='b', patch_size=16, mask_ratio=0.75, out_indices=[1, 11])
+    mae_backbone = MFFViT(**backbone)
+    mae_backbone.init_weights()
+    fake_inputs = torch.randn((2, 3, 224, 224))
+
+    # test with mask
+    fake_outputs = mae_backbone(fake_inputs)[0]
+    assert list(fake_outputs.shape) == [2, 50, 768]
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
+def test_mae():
+    data_preprocessor = {
+        'mean': [0.5, 0.5, 0.5],
+        'std': [0.5, 0.5, 0.5],
+        'to_rgb': True
+    }
+    backbone = dict(
+        type='MFFViT',
+        arch='b',
+        patch_size=16,
+        mask_ratio=0.75,
+        out_indices=[1, 11])
+    neck = dict(
+        type='MAEPretrainDecoder',
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        decoder_embed_dim=512,
+        decoder_depth=8,
+        decoder_num_heads=16,
+        mlp_ratio=4.,
+    )
+    loss = dict(type='PixelReconstructionLoss', criterion='L2')
+    head = dict(
+        type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)
+
+    alg = MFF(
+        backbone=backbone,
+        neck=neck,
+        head=head,
+        data_preprocessor=data_preprocessor)
+
+    fake_data = {
+        'inputs': torch.randn((2, 3, 224, 224)),
+        'data_samples': [DataSample() for _ in range(2)]
+    }
+    fake_inputs = alg.data_preprocessor(fake_data)
+    fake_outputs = alg(**fake_inputs, mode='loss')
+    assert isinstance(fake_outputs['loss'].item(), float)