From 034919d0326da74e601a14bb7e3b6074e0d14fb9 Mon Sep 17 00:00:00 2001
From: zzc98 <40905160+zzc98@users.noreply.github.com>
Date: Sat, 6 May 2023 19:28:31 +0800
Subject: [PATCH] [Feature] add eva02 backbone (#1450)

* [CI] Add test mim CI. (#879)

* [CI] Add test mim CI. (#879)

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update ci

* rebase

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update readme and configs

* update readme and configs

* refactore eva02

* [CI] Add test mim CI. (#879)

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update ci

* rebase

* feat: add eva02 backbone

* feat: add eva02 backbone

* feat: add eva02 backbone

* update

* update readme and configs

* refactore eva02

* update readme and metafile

* update readme and metafile

* update readme and metafile

* update

* rename eva02

* rename eva02

* fix uts

* rename configs

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
---
 .../_base_/datasets/imagenet_bs16_eva_448.py  |  62 ++++
 configs/eva02/README.md                       | 109 ++++++
 configs/eva02/eva02-base-p14_headless.py      |  21 ++
 configs/eva02/eva02-base-p14_in1k.py          |  32 ++
 configs/eva02/eva02-large-p14_headless.py     |  21 ++
 configs/eva02/eva02-large-p14_in1k.py         |  32 ++
 configs/eva02/eva02-small-p14_headless.py     |  20 +
 configs/eva02/eva02-small-p14_in1k.py         |  31 ++
 configs/eva02/eva02-tiny-p14_headless.py      |  20 +
 configs/eva02/eva02-tiny-p14_in1k.py          |  31 ++
 configs/eva02/metafile.yml                    | 199 ++++++++++
 docs/en/api/models.rst                        |   1 +
 mmpretrain/models/backbones/__init__.py       |   2 +
 mmpretrain/models/backbones/vit_eva02.py      | 350 ++++++++++++++++++
 mmpretrain/models/utils/__init__.py           |   3 +-
 mmpretrain/models/utils/position_encoding.py  |  75 ++++
 model-index.yml                               |   1 +
 .../test_models/test_backbones/test_eva02.py  | 143 +++++++
 .../test_utils/test_position_encoding.py      |  13 +-
 tools/model_converters/eva02_to_mmpretrain.py | 153 ++++++++
 20 files changed, 1317 insertions(+), 2 deletions(-)
 create mode 100644 configs/_base_/datasets/imagenet_bs16_eva_448.py
 create mode 100644 configs/eva02/README.md
 create mode 100644 configs/eva02/eva02-base-p14_headless.py
 create mode 100644 configs/eva02/eva02-base-p14_in1k.py
 create mode 100644 configs/eva02/eva02-large-p14_headless.py
 create mode 100644 configs/eva02/eva02-large-p14_in1k.py
 create mode 100644 configs/eva02/eva02-small-p14_headless.py
 create mode 100644 configs/eva02/eva02-small-p14_in1k.py
 create mode 100644 configs/eva02/eva02-tiny-p14_headless.py
 create mode 100644 configs/eva02/eva02-tiny-p14_in1k.py
 create mode 100644 configs/eva02/metafile.yml
 create mode 100644 mmpretrain/models/backbones/vit_eva02.py
 create mode 100644 tests/test_models/test_backbones/test_eva02.py
 create mode 100644 tools/model_converters/eva02_to_mmpretrain.py

diff --git a/configs/_base_/datasets/imagenet_bs16_eva_448.py b/configs/_base_/datasets/imagenet_bs16_eva_448.py
new file mode 100644
index 00000000..b90bba14
--- /dev/null
+++ b/configs/_base_/datasets/imagenet_bs16_eva_448.py
@@ -0,0 +1,62 @@
+# dataset settings
+dataset_type = 'ImageNet'
+data_preprocessor = dict(
+    num_classes=1000,
+    # RGB format normalization parameters
+    mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
+    std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        scale=448,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
+    dict(type='PackInputs'),
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='ResizeEdge',
+        scale=448,
+        edge='short',
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=448),
+    dict(type='PackInputs'),
+]
+
+train_dataloader = dict(
+    batch_size=16,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/train.txt',
+        data_prefix='train',
+        pipeline=train_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=True),
+)
+
+val_dataloader = dict(
+    batch_size=8,
+    num_workers=5,
+    dataset=dict(
+        type=dataset_type,
+        data_root='data/imagenet',
+        ann_file='meta/val.txt',
+        data_prefix='val',
+        pipeline=test_pipeline),
+    sampler=dict(type='DefaultSampler', shuffle=False),
+)
+val_evaluator = dict(type='Accuracy', topk=(1, 5))
+
+# If you want standard test, please manually configure the test dataset
+test_dataloader = val_dataloader
+test_evaluator = val_evaluator
diff --git a/configs/eva02/README.md b/configs/eva02/README.md
new file mode 100644
index 00000000..bf0cea78
--- /dev/null
+++ b/configs/eva02/README.md
@@ -0,0 +1,109 @@
+# EVA-02
+
+> [EVA-02: A Visual Representation for Neon Genesis](https://arxiv.org/abs/2303.11331)
+
+<!-- [ALGORITHM] -->
+
+## Abstract
+
+We launch EVA-02, a next-generation Transformer-based visual representation pre-trained to reconstruct strong and robust language-aligned vision features via masked image modeling. With an updated plain Transformer architecture as well as extensive pre-training from an open & accessible giant CLIP vision encoder, EVA-02 demonstrates superior performance compared to prior state-of-the-art approaches across various representative vision tasks, while utilizing significantly fewer parameters and compute budgets. Notably, using exclusively publicly accessible training data, EVA-02 with only 304M parameters achieves a phenomenal 90.0 fine-tuning top-1 accuracy on ImageNet-1K val set.  Additionally, our EVA-02-CLIP can reach up to 80.4 zero-shot top-1 on ImageNet-1K, outperforming the previous largest & best open-sourced CLIP with only ~1/6 parameters and ~1/6 image-text training data. We offer four EVA-02 variants in various model sizes, ranging from 6M to 304M parameters, all with impressive performance. To facilitate open accessand open research, we release the complete suite of EVA-02 to the community.
+
+<div align=center>
+<img src="https://user-images.githubusercontent.com/40905160/229037980-b83dceb5-41d6-406c-a20b-63b83c80136d.png" width="70%" alt="TrV builds upon the original plain ViT architecture and includes several enhancements: SwinGLU FFN, sub-LN, 2D RoPE, and JAX weight initialization. To keep the parameter & FLOPs consistent with the baseline, the FFN hidden dim of SwiGLU is 2/3× of the typical MLP counterpart."/>
+</div>
+
+## How to use it?
+
+<!-- [TABS-BEGIN] -->
+
+**Predict image**
+
+```python
+from mmpretrain import inference_model
+
+predict = inference_model('vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px', 'demo/bird.JPEG')
+print(predict['pred_class'])
+print(predict['pred_score'])
+```
+
+**Use the model**
+
+```python
+import torch
+from mmpretrain import get_model
+
+model = get_model('vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px', pretrained=True)
+inputs = torch.rand(1, 3, 336, 336)
+out = model(inputs)
+print(type(out))
+# To extract features.
+feats = model.extract_feat(inputs)
+print(type(feats))
+```
+
+**Train/Test Command**
+
+Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
+
+Train:
+
+```shell
+python tools/train.py configs/eva02/eva02-tiny-p14_in1k.py
+```
+
+Test:
+
+```shell
+python tools/test.py configs/eva02/eva02-tiny-p14_in1k.py /path/to/eva02-tiny-p14_in1k.pth
+```
+
+<!-- [TABS-END] -->
+
+## Models and results
+
+### Pretrained models
+
+| Model                             | Params (M) | Flops (G) |                Config                 |                                                   Download                                                    |
+| :-------------------------------- | :--------: | :-------: | :-----------------------------------: | :-----------------------------------------------------------------------------------------------------------: |
+| `vit-tiny-p14_eva02-pre_in21k`\*  |    5.50    |   1.70    | [config](eva02-tiny-p14_headless.py)  | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_pre_in21k_20230505-d703e7b1.pth)  |
+| `vit-small-p14_eva02-pre_in21k`\* |   21.62    |   6.14    | [config](eva02-small-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_pre_in21k_20230505-3175f463.pth) |
+| `vit-base-p14_eva02-pre_in21k`\*  |   85.77    |   23.22   | [config](eva02-base-p14_headless.py)  | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_pre_in21k_20230505-2f2d4d3c.pth)  |
+| `vit-large-p14_eva02-pre_in21k`\* |   303.29   |   81.15   | [config](eva02-large-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_in21k_20230505-9072de5d.pth) |
+| `vit-large-p14_eva02-pre_m38m`\*  |   303.29   |   81.15   | [config](eva02-large-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_m38m_20230505-b8a1a261.pth)  |
+
+- The input size / patch size of MIM pre-trained EVA-02 is `224x224` / `14x14`.
+
+*Models with * are converted from the [official repo](https://github.com/baaivision/EVA).*
+
+### Image Classification on ImageNet-1k
+
+#### (*w/o* IN-21K intermediate fine-tuning)
+
+| Model                                                 |      Pretrain      | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) |               Config                |                         Download                          |
+| :---------------------------------------------------- | :----------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------: | :-------------------------------------------------------: |
+| `vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px`\*  | EVA02 ImageNet-21k |    5.76    |   4.68    |   80.69   |   95.54   | [config](./eva02-tiny-p14_in1k.py)  | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_in21k-pre_3rdparty_in1k-336px_20230505-a4e8708a.pth) |
+| `vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px`\* | EVA02 ImageNet-21k |   22.13    |   15.48   |   85.78   |   97.60   | [config](./eva02-small-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_in21k-pre_3rdparty_in1k-336px_20230505-9c5b0e85.pth) |
+| `vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px`\*  | EVA02 ImageNet-21k |   87.13    |  107.11   |   88.29   |   98.53   | [config](./eva02-base-p14_in1k.py)  | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_3rdparty_in1k-448px_20230505-8ad211c5.pth) |
+
+*Models with * are converted from the  [official repo](https://github.com/baaivision/EVA/tree/master/EVA-02). The config files of these models are only for inference. We haven't reprodcue the training results.*
+
+#### (*w* IN-21K intermediate fine-tuning)
+
+| Model                                                 |      Pretrain      | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) |               Config                |                         Download                          |
+| :---------------------------------------------------- | :----------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------: | :-------------------------------------------------------: |
+| `vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k |   87.13    |  107.11   |   88.47   |   98.62   | [config](./eva02-base-p14_in1k.py)  | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-5cd4d87f.pth) |
+| `vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k |   305.08   |  362.33   |   89.65   |   98.95   | [config](./eva02-large-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-926d1599.pth) |
+| `vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px`\* |  EVA02 Merged-38M  |   305.10   |  362.33   |   89.83   |   99.00   | [config](./eva02-large-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_m38m-pre_in21k-medft_3rdparty_in1k-448px_20230505-150dc5ed.pth) |
+
+*Models with * are converted from the  [official repo](https://github.com/baaivision/EVA/tree/master/EVA-02). The config files of these models are only for inference. We haven't reprodcue the training results.*
+
+## Citation
+
+```bibtex
+@article{EVA-02,
+  title={EVA-02: A Visual Representation for Neon Genesis},
+  author={Yuxin Fang and Quan Sun and Xinggang Wang and Tiejun Huang and Xinlong Wang and Yue Cao},
+  journal={arXiv preprint arXiv:2303.11331},
+  year={2023}
+}
+```
diff --git a/configs/eva02/eva02-base-p14_headless.py b/configs/eva02/eva02-base-p14_headless.py
new file mode 100644
index 00000000..27aa8f8a
--- /dev/null
+++ b/configs/eva02/eva02-base-p14_headless.py
@@ -0,0 +1,21 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='b',
+        img_size=224,
+        patch_size=14,
+        sub_ln=True,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=None,
+)
+
+data_preprocessor = dict(
+    # RGB format normalization parameters
+    mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
+    std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
diff --git a/configs/eva02/eva02-base-p14_in1k.py b/configs/eva02/eva02-base-p14_in1k.py
new file mode 100644
index 00000000..c8400d38
--- /dev/null
+++ b/configs/eva02/eva02-base-p14_in1k.py
@@ -0,0 +1,32 @@
+_base_ = [
+    '../_base_/datasets/imagenet_bs16_eva_448.py',
+    '../_base_/schedules/imagenet_bs2048_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='b',
+        img_size=448,
+        patch_size=14,
+        sub_ln=True,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=768,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
+    ]))
diff --git a/configs/eva02/eva02-large-p14_headless.py b/configs/eva02/eva02-large-p14_headless.py
new file mode 100644
index 00000000..e101ac97
--- /dev/null
+++ b/configs/eva02/eva02-large-p14_headless.py
@@ -0,0 +1,21 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='l',
+        img_size=224,
+        patch_size=14,
+        sub_ln=True,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=None,
+)
+
+data_preprocessor = dict(
+    # RGB format normalization parameters
+    mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
+    std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
diff --git a/configs/eva02/eva02-large-p14_in1k.py b/configs/eva02/eva02-large-p14_in1k.py
new file mode 100644
index 00000000..91a42776
--- /dev/null
+++ b/configs/eva02/eva02-large-p14_in1k.py
@@ -0,0 +1,32 @@
+_base_ = [
+    '../_base_/datasets/imagenet_bs16_eva_448.py',
+    '../_base_/schedules/imagenet_bs2048_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='l',
+        img_size=448,
+        patch_size=14,
+        sub_ln=True,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=1024,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
+    ]))
diff --git a/configs/eva02/eva02-small-p14_headless.py b/configs/eva02/eva02-small-p14_headless.py
new file mode 100644
index 00000000..a9698193
--- /dev/null
+++ b/configs/eva02/eva02-small-p14_headless.py
@@ -0,0 +1,20 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='s',
+        img_size=224,
+        patch_size=14,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=None,
+)
+
+data_preprocessor = dict(
+    # RGB format normalization parameters
+    mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
+    std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
diff --git a/configs/eva02/eva02-small-p14_in1k.py b/configs/eva02/eva02-small-p14_in1k.py
new file mode 100644
index 00000000..4a16d924
--- /dev/null
+++ b/configs/eva02/eva02-small-p14_in1k.py
@@ -0,0 +1,31 @@
+_base_ = [
+    '../_base_/datasets/imagenet_bs16_eva_336.py',
+    '../_base_/schedules/imagenet_bs2048_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='s',
+        img_size=336,
+        patch_size=14,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=384,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
+    ]))
diff --git a/configs/eva02/eva02-tiny-p14_headless.py b/configs/eva02/eva02-tiny-p14_headless.py
new file mode 100644
index 00000000..783d0ea2
--- /dev/null
+++ b/configs/eva02/eva02-tiny-p14_headless.py
@@ -0,0 +1,20 @@
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='t',
+        img_size=224,
+        patch_size=14,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=None,
+)
+
+data_preprocessor = dict(
+    # RGB format normalization parameters
+    mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
+    std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
+    # convert image from BGR to RGB
+    to_rgb=True,
+)
diff --git a/configs/eva02/eva02-tiny-p14_in1k.py b/configs/eva02/eva02-tiny-p14_in1k.py
new file mode 100644
index 00000000..84e68d7e
--- /dev/null
+++ b/configs/eva02/eva02-tiny-p14_in1k.py
@@ -0,0 +1,31 @@
+_base_ = [
+    '../_base_/datasets/imagenet_bs16_eva_336.py',
+    '../_base_/schedules/imagenet_bs2048_AdamW.py',
+    '../_base_/default_runtime.py'
+]
+
+model = dict(
+    type='ImageClassifier',
+    backbone=dict(
+        type='ViTEVA02',
+        arch='t',
+        img_size=336,
+        patch_size=14,
+        final_norm=False,
+        out_type='avg_featmap'),
+    neck=None,
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=192,
+        loss=dict(
+            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer='Linear', std=.02),
+        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
+    ],
+    train_cfg=dict(augments=[
+        dict(type='Mixup', alpha=0.8),
+        dict(type='CutMix', alpha=1.0)
+    ]))
diff --git a/configs/eva02/metafile.yml b/configs/eva02/metafile.yml
new file mode 100644
index 00000000..80acf904
--- /dev/null
+++ b/configs/eva02/metafile.yml
@@ -0,0 +1,199 @@
+Collections:
+  - Name: EVA02
+    Metadata:
+      Architecture:
+        - Rotary Position Embedding
+        - Sub Layer Normalization
+        - SwiGLU
+    Paper:
+      Title: 'EVA-02: A Visual Representation for Neon Genesis'
+      URL: https://arxiv.org/abs/2303.11331
+    README: configs/eva02/README.md
+
+Models:
+  - Name: vit-tiny-p14_eva02-pre_in21k
+    Metadata:
+      FLOPs: 1703439360
+      Parameters: 5504064
+      Training Data:
+        - ImageNet-21k
+    In Collection: EVA02
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_pre_in21k_20230505-d703e7b1.pth
+    Config: configs/eva02/eva02-tiny-p14_headless.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_Ti_pt_in21k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+    Downstream:
+      - vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px
+  - Name: vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px
+    Metadata:
+      FLOPs: 4675416000
+      Parameters: 5758888
+      Training Data:
+        - ImageNet-21k
+        - ImageNet-1k
+    In Collection: EVA02
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 80.69
+          Top 5 Accuracy: 95.54
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_in21k-pre_3rdparty_in1k-336px_20230505-a4e8708a.pth
+    Config: configs/eva02/eva02-tiny-p14_in1k.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+  - Name: vit-small-p14_eva02-pre_in21k
+    Metadata:
+      FLOPs: 6135404544
+      Parameters: 21624960
+      Training Data:
+        - ImageNet-21k
+    In Collection: EVA02
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_pre_in21k_20230505-3175f463.pth
+    Config: configs/eva02/eva02-small-p14_headless.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_S_pt_in21k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+    Downstream:
+      - vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px
+  - Name: vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px
+    Metadata:
+      FLOPs: 15476744064
+      Parameters: 22133608
+      Training Data:
+        - ImageNet-21k
+        - ImageNet-1k
+    In Collection: EVA02
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 85.78
+          Top 5 Accuracy: 97.60
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_in21k-pre_3rdparty_in1k-336px_20230505-9c5b0e85.pth
+    Config: configs/eva02/eva02-small-p14_in1k.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+  - Name: vit-base-p14_eva02-pre_in21k
+    Metadata:
+      FLOPs: 23216492544
+      Parameters: 85766400
+      Training Data:
+        - ImageNet-21k
+    In Collection: EVA02
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_pre_in21k_20230505-2f2d4d3c.pth
+    Config: configs/eva02/eva02-base-p14_headless.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_B_pt_in21k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+    Downstream:
+      - vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px
+      - vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
+  - Name: vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px
+    Metadata:
+      FLOPs: 107105984256
+      Parameters: 87126760
+      Training Data:
+        - ImageNet-21k
+        - ImageNet-1k
+    In Collection: EVA02
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 88.29
+          Top 5 Accuracy: 98.53
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_3rdparty_in1k-448px_20230505-8ad211c5.pth
+    Config: configs/eva02/eva02-base-p14_in1k.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+  - Name: vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
+    Metadata:
+      FLOPs: 107105984256
+      Parameters: 87126760
+      Training Data:
+        - ImageNet-21k
+        - ImageNet-1k
+    In Collection: EVA02
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 88.47
+          Top 5 Accuracy: 98.62
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-5cd4d87f.pth
+    Config: configs/eva02/eva02-base-p14_in1k.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+  - Name: vit-large-p14_eva02-pre_in21k
+    Metadata:
+      FLOPs: 81146703792
+      Parameters: 303291328
+      Training Data:
+        - ImageNet-21k
+    In Collection: EVA02
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_in21k_20230505-9072de5d.pth
+    Config: configs/eva02/eva02-large-p14_headless.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_L_pt_in21k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+    Downstream:
+      - vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
+  - Name: vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
+    Metadata:
+      FLOPs: 362333836208
+      Parameters: 305104808
+      Training Data:
+        - ImageNet-21k
+        - ImageNet-1k
+    In Collection: EVA02
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 89.65
+          Top 5 Accuracy: 98.95
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-926d1599.pth
+    Config: configs/eva02/eva02-large-p14_in1k.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+  - Name: vit-large-p14_eva02-pre_m38m
+    Metadata:
+      FLOPs: 81146703792
+      Parameters: 303291328
+      Training Data:
+        - Merged-38M
+    In Collection: EVA02
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_m38m_20230505-b8a1a261.pth
+    Config: configs/eva02/eva02-large-p14_headless.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_L_pt_m38m_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
+    Downstream:
+      - vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px
+  - Name: vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px
+    Metadata:
+      FLOPs: 362333836208
+      Parameters: 305104808
+      Training Data:
+        - Merged-38M
+        - ImageNet-21k
+        - ImageNet-1k
+    In Collection: EVA02
+    Results:
+      - Dataset: ImageNet-1k
+        Task: Image Classification
+        Metrics:
+          Top 1 Accuracy: 89.83
+          Top 5 Accuracy: 99.00
+    Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_m38m-pre_in21k-medft_3rdparty_in1k-448px_20230505-150dc5ed.pth
+    Config: configs/eva02/eva02-large-p14_in1k.py
+    Converted From:
+      Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt
+      Code: https://github.com/baaivision/EVA/tree/master/EVA-02
diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 42757862..7b6d607a 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -189,6 +189,7 @@ Backbones
    VisionTransformer
    ViTSAM
    XCiT
+   ViTEVA02
 
 .. module:: mmpretrain.models.necks
 
diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py
index ab77dd65..d9830f12 100644
--- a/mmpretrain/models/backbones/__init__.py
+++ b/mmpretrain/models/backbones/__init__.py
@@ -52,6 +52,7 @@ from .van import VAN
 from .vgg import VGG
 from .vig import PyramidVig, Vig
 from .vision_transformer import VisionTransformer
+from .vit_eva02 import ViTEVA02
 from .vit_sam import ViTSAM
 from .xcit import XCiT
 
@@ -118,4 +119,5 @@ __all__ = [
     'PyramidVig',
     'XCiT',
     'ViTSAM',
+    'ViTEVA02',
 ]
diff --git a/mmpretrain/models/backbones/vit_eva02.py b/mmpretrain/models/backbones/vit_eva02.py
new file mode 100644
index 00000000..20ec4b24
--- /dev/null
+++ b/mmpretrain/models/backbones/vit_eva02.py
@@ -0,0 +1,350 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks.drop import build_dropout
+from mmengine.model import BaseModule, ModuleList
+
+from mmpretrain.registry import MODELS
+from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer,
+                     resize_pos_embed)
+from .vision_transformer import VisionTransformer
+
+
+class AttentionWithRoPE(BaseModule):
+    """Multi-head Attention Module with 2D sincos position embedding (RoPE).
+
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        attn_drop (float): Dropout rate of the dropout layer after the
+            attention calculation of query and key. Defaults to 0.
+        proj_drop (float): Dropout rate of the dropout layer after the
+            output projection. Defaults to 0.
+        qkv_bias (bool): If True, add a learnable bias to q and v. Note
+            that we follows the official implementation where ``k_bias``
+            is 0. Defaults to True.
+        qk_scale (float, optional): Override default qk scale of
+            ``head_dim ** -0.5`` if set. Defaults to None.
+        proj_bias (bool) If True, add a learnable bias to output projection.
+            Defaults to True.
+        rope (:obj:`torch.nn.Module`, optional): If it is an object of the
+            ``RotaryEmbedding``, the rotation of the token position will be
+            performed before the softmax. Defaults to None.
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input. Defaults to True.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 proj_bias=True,
+                 rope=None,
+                 with_cls_token=True,
+                 init_cfg=None):
+        super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.head_dims = embed_dims // num_heads
+        self.scale = qk_scale or self.head_dims**-0.5
+        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.with_cls_token = with_cls_token
+
+        self.rope = rope
+
+    def forward(self, x, patch_resolution):
+        B, N, _ = x.shape
+
+        qkv = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(dim=0)
+
+        if self.rope:
+            if self.with_cls_token:
+                q_t = q[:, :, 1:, :]
+                ro_q_t = self.rope(q_t, patch_resolution)
+                q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
+
+                k_t = k[:, :, 1:, :] if self.with_cls_token else k
+                ro_k_t = self.rope(k_t, patch_resolution)
+                k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
+            else:
+                q = self.rope(q, patch_resolution)
+                k = self.rope(k, patch_resolution)
+
+        q = q * self.scale
+
+        attn = (q @ k.transpose(-2, -1))
+        attn = attn.softmax(dim=-1).type_as(x)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+class EVA02EndcoderLayer(BaseModule):
+    """Implements one encoder EVA02EndcoderLayer in EVA02.
+
+    Args:
+        embed_dims (int): The feature dimension
+        num_heads (int): Parallel attention heads
+        feedforward_channels (int): The hidden dimension of FFNs.
+        sub_ln (bool): Whether to add the sub layer normalization
+            in the attention module. Defaults to False.
+        attn_drop (float): Dropout rate of the dropout layer after the
+            attention calculation of query and key. Defaults to 0.
+        proj_drop (float): Dropout rate of the dropout layer after the
+            output projection. Defaults to 0.
+        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
+        qk_scale (float, optional): Override default qk scale of
+            ``head_dim ** -0.5`` if set. Defaults to None.
+        proj_bias (bool): enable bias for projection in the attention module
+            if True. Defaults to True.
+        rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object
+            in the attention module. Defaults to None.
+        drop_rate (float): Dropout rate in the mlp module. Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 feedforward_channels,
+                 sub_ln=False,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 proj_bias=True,
+                 rope=None,
+                 with_cls_token=True,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None):
+        super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg)
+
+        self.norm1 = build_norm_layer(norm_cfg, embed_dims)
+
+        self.attn = AttentionWithRoPE(
+            embed_dims=embed_dims,
+            num_heads=num_heads,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            proj_bias=proj_bias,
+            rope=rope,
+            with_cls_token=with_cls_token)
+
+        self.drop_path = build_dropout(
+            dict(type='DropPath', drop_prob=drop_path_rate))
+
+        self.norm2 = build_norm_layer(norm_cfg, embed_dims)
+
+        if drop_rate > 0:
+            dropout_layer = dict(type='Dropout', drop_prob=drop_rate)
+        else:
+            dropout_layer = None
+
+        if sub_ln:
+            ffn_norm = norm_cfg
+        else:
+            ffn_norm = None
+
+        self.mlp = SwiGLUFFN(
+            embed_dims=embed_dims,
+            feedforward_channels=feedforward_channels,
+            dropout_layer=dropout_layer,
+            norm_cfg=ffn_norm,
+            add_identity=False,
+        )
+
+    def forward(self, x, patch_resolution):
+        inputs = x
+        x = self.norm1(x)
+        x = self.attn(x, patch_resolution)
+        x = self.drop_path(x)
+        x = inputs + x
+
+        inputs = x
+        x = self.norm2(x)
+        x = self.mlp(x)
+        x = self.drop_path(x)
+        x = inputs + x
+
+        return x
+
+
+@MODELS.register_module()
+class ViTEVA02(VisionTransformer):
+    """EVA02 Vision Transformer.
+
+    A PyTorch implement of : `EVA-02: A Visual Representation for Neon Genesis
+    <https://arxiv.org/abs/2303.11331>`_
+
+    Args:
+        arch (str | dict): Vision Transformer architecture. If use string,
+            choose from 'tiny', 'small', 'base', 'large'. If use dict,
+            it should have below keys:
+
+            - **embed_dims** (int): The dimensions of embedding.
+            - **num_layers** (int): The number of transformer encoder layers.
+            - **num_heads** (int): The number of heads in attention modules.
+            - **mlp_ratio** (float): The ratio of the mlp module.
+
+            Defaults to 'tiny'.
+
+        sub_ln (bool): Whether to add the sub layer normalization in swiglu.
+            Defaults to False.
+        drop_rate (float): Probability of an element to be zeroed in the
+            mlp module. Defaults to 0.
+        attn_drop_rate (float): Probability of an element to be zeroed after
+            the softmax in the attention. Defaults to 0.
+        proj_drop_rate (float): Probability of an element to be zeroed after
+            projection in the attention. Defaults to 0.
+        drop_path_rate (float): stochastic depth rate. Defaults to 0.
+        qkv_bias (bool): Whether to add bias for qkv in attention modules.
+            Defaults to True.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults to ``dict(type='LN')``.
+        with_cls_token (bool): Whether concatenating class token into image
+            tokens as transformer input. Defaults to True.
+        layer_cfgs (Sequence | dict): Configs of each transformer layer in
+            encoder. Defaults to an empty dict.
+        **kwargs(dict, optional): Other args for Vision Transformer.
+    """
+    arch_zoo = {
+        **dict.fromkeys(
+            ['t', 'ti', 'tiny'], {
+                'embed_dims': 192,
+                'num_layers': 12,
+                'num_heads': 3,
+                'feedforward_channels': int(192 * 4 * 2 / 3)
+            }),
+        **dict.fromkeys(
+            ['s', 'small'], {
+                'embed_dims': 384,
+                'num_layers': 12,
+                'num_heads': 6,
+                'feedforward_channels': int(384 * 4 * 2 / 3)
+            }),
+        **dict.fromkeys(
+            ['b', 'base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': int(768 * 4 * 2 / 3)
+            }),
+        **dict.fromkeys(
+            ['l', 'large'], {
+                'embed_dims': 1024,
+                'num_layers': 24,
+                'num_heads': 16,
+                'feedforward_channels': int(1024 * 4 * 2 / 3)
+            })
+    }
+    num_extra_tokens = 1  # class token
+    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
+
+    def __init__(self,
+                 arch='tiny',
+                 sub_ln=False,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 proj_drop_rate=0.,
+                 drop_path_rate=0.,
+                 qkv_bias=True,
+                 norm_cfg=dict(type='LN'),
+                 with_cls_token=True,
+                 layer_cfgs=dict(),
+                 **kwargs):
+        # set essential args for Vision Transformer
+        kwargs.update(
+            arch=arch,
+            drop_rate=drop_rate,
+            drop_path_rate=drop_path_rate,
+            norm_cfg=norm_cfg,
+            with_cls_token=with_cls_token)
+        super(ViTEVA02, self).__init__(**kwargs)
+
+        self.num_heads = self.arch_settings['num_heads']
+
+        # Set RoPE
+        head_dim = self.embed_dims // self.num_heads
+        self.rope = RotaryEmbeddingFast(
+            embed_dims=head_dim, patch_resolution=self.patch_resolution)
+
+        # stochastic depth decay rule
+        dpr = np.linspace(0, drop_path_rate, self.num_layers)
+        self.layers = ModuleList()
+        if isinstance(layer_cfgs, dict):
+            layer_cfgs = [layer_cfgs] * self.num_layers
+        for i in range(self.num_layers):
+            _layer_cfg = dict(
+                embed_dims=self.embed_dims,
+                num_heads=self.num_heads,
+                feedforward_channels=self.
+                arch_settings['feedforward_channels'],
+                sub_ln=sub_ln,
+                norm_cfg=norm_cfg,
+                proj_drop=proj_drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_rate=drop_rate,
+                qkv_bias=qkv_bias,
+                rope=self.rope,
+                with_cls_token=with_cls_token,
+                drop_path_rate=dpr[i])
+            _layer_cfg.update(layer_cfgs[i])
+            self.layers.append(EVA02EndcoderLayer(**_layer_cfg))
+
+    def forward(self, x):
+        B = x.shape[0]
+        x, patch_resolution = self.patch_embed(x)
+
+        if self.cls_token is not None:
+            # stole cls_tokens impl from Phil Wang, thanks
+            cls_tokens = self.cls_token.expand(B, -1, -1)
+            x = torch.cat((cls_tokens, x), dim=1)
+
+        x = x + resize_pos_embed(
+            self.pos_embed,
+            self.patch_resolution,
+            patch_resolution,
+            mode=self.interpolate_mode,
+            num_extra_tokens=self.num_extra_tokens)
+        x = self.drop_after_pos(x)
+
+        x = self.pre_norm(x)
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x, patch_resolution)
+
+            if i == len(self.layers) - 1 and self.final_norm:
+                x = self.ln1(x)
+
+            if i in self.out_indices:
+                outs.append(self._format_output(x, patch_resolution))
+
+        return tuple(outs)
diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py
index 904de6b7..b7df9e41 100644
--- a/mmpretrain/models/utils/__init__.py
+++ b/mmpretrain/models/utils/__init__.py
@@ -18,7 +18,7 @@ from .layer_scale import LayerScale
 from .make_divisible import make_divisible
 from .norm import GRN, LayerNorm2d, build_norm_layer
 from .position_encoding import (ConditionalPositionEncoding,
-                                PositionEncodingFourier,
+                                PositionEncodingFourier, RotaryEmbeddingFast,
                                 build_2d_sincos_position_embedding)
 from .res_layer_extra_norm import ResLayerExtraNorm
 from .se_layer import SELayer
@@ -72,4 +72,5 @@ __all__ = [
     'ResLayerExtraNorm',
     'SwiGLUFFN',
     'SwiGLUFFNFused',
+    'RotaryEmbeddingFast',
 ]
diff --git a/mmpretrain/models/utils/position_encoding.py b/mmpretrain/models/utils/position_encoding.py
index a200c066..07a3c486 100644
--- a/mmpretrain/models/utils/position_encoding.py
+++ b/mmpretrain/models/utils/position_encoding.py
@@ -8,6 +8,8 @@ import torch.nn as nn
 from mmengine.model import BaseModule
 from mmengine.utils import digit_version
 
+from ..utils import to_2tuple
+
 # After pytorch v1.10.0, use torch.meshgrid without indexing
 # will raise extra warning. For more details,
 # refers to https://github.com/pytorch/pytorch/issues/50276
@@ -170,3 +172,76 @@ def build_2d_sincos_position_embedding(
         pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
 
     return pos_emb
+
+
+class RotaryEmbeddingFast(BaseModule):
+    """Implements 2D rotary embedding (RoPE) for image tokens. Position
+    encoding is implemented with sin and cos functions,
+
+        .. math::
+            Pos_{cos} = cos(\frac{t}{\theta^{\frac{2i}{d}}} \\
+            Pos_{sin} = sin(\frac{t}{\theta^{\frac{2i}{d}}}
+    Args:
+        embed_dims (int): The feature dimension for each head.
+        patch_resolution (int | tuple): The resolution of the
+            image, in format (H, W).
+        theta (float): The hyperparameter for position coding.
+            Defaults to 10000.
+        init_cfg (dict, optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 patch_resolution,
+                 theta=10000.,
+                 init_cfg=None):
+        super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg)
+
+        self.half_dim = embed_dims // 2
+        self.patch_resolution = to_2tuple(patch_resolution)
+        self.theta = theta
+
+        freqs_cos, freqs_sin = self.compute_position_embedding()
+        self.register_buffer('freqs_cos', freqs_cos)
+        self.register_buffer('freqs_sin', freqs_sin)
+
+    def compute_position_embedding(self):
+        frequency = self.theta**(
+            torch.arange(0, self.half_dim, 2).float() / self.half_dim)
+        frequency = 1. / frequency
+
+        h, w = self.patch_resolution
+        th = torch.arange(h) / h * self.half_dim
+        tw = torch.arange(w) / w * self.half_dim
+
+        position_h = (th[:, None] @ frequency[None, :]).repeat(1, 2)
+        position_w = (tw[:, None] @ frequency[None, :]).repeat(1, 2)
+
+        height = position_h[:, None, :].expand(h, w, self.half_dim)
+        width = position_w[None, :, :].expand(h, w, self.half_dim)
+        position = torch.cat((height, width), dim=-1)
+
+        freqs_cos = position.cos().view(-1, position.shape[-1])
+        freqs_sin = position.sin().view(-1, position.shape[-1])
+
+        return freqs_cos, freqs_sin
+
+    def forward(self, x, patch_resolution):
+        # Check whether the patch resolution is the predefined size
+        patch_resolution = to_2tuple(patch_resolution)
+        if patch_resolution != self.patch_resolution:
+            self.patch_resolution = patch_resolution
+            freqs_cos, freqs_sin = self.compute_position_embedding()
+            self.register_buffer('freqs_cos', freqs_cos.to(x.device))
+            self.register_buffer('freqs_sin', freqs_sin.to(x.device))
+
+        batch, num_heads, num_patches, dim = x.shape
+
+        inputs = x
+        x = x.reshape(batch, num_heads, num_patches, -1, 2)
+        x1, x2 = x.unbind(dim=-1)
+        x = torch.stack((-x2, x1), dim=-1)
+        x = x.reshape(batch, num_heads, num_patches, dim)
+
+        return inputs * self.freqs_cos + x * self.freqs_sin
diff --git a/model-index.yml b/model-index.yml
index 8df1d3d3..c960b360 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -69,4 +69,5 @@ Import:
   - configs/riformer/metafile.yml
   - configs/sam/metafile.yml
   - configs/glip/metafile.yml
+  - configs/eva02/metafile.yml
   - configs/dinov2/metafile.yml
diff --git a/tests/test_models/test_backbones/test_eva02.py b/tests/test_models/test_backbones/test_eva02.py
new file mode 100644
index 00000000..06727542
--- /dev/null
+++ b/tests/test_models/test_backbones/test_eva02.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from unittest import TestCase
+
+import torch
+
+from mmpretrain.models.backbones import ViTEVA02
+
+
+class TestEVA02(TestCase):
+
+    def setUp(self):
+        self.cfg = dict(
+            arch='t',
+            img_size=336,
+            patch_size=14,
+            drop_path_rate=0.1,
+            drop_rate=0.1,
+            attn_drop_rate=0.2,
+            proj_drop_rate=0.3,
+        )
+
+    def test_structure(self):
+        # Test invalid default arch
+        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
+            cfg = deepcopy(self.cfg)
+            cfg['arch'] = 'unknown'
+            ViTEVA02(**cfg)
+
+        # Test invalid custom arch
+        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
+            cfg = deepcopy(self.cfg)
+            cfg['arch'] = {
+                'num_layers': 24,
+                'num_heads': 16,
+                'feedforward_channels': int(24 * 4 * 2 / 3)
+            }
+            ViTEVA02(**cfg)
+
+        # Test custom arch
+        cfg = deepcopy(self.cfg)
+        cfg['arch'] = {
+            'embed_dims': 128,
+            'num_layers': 6,
+            'num_heads': 16,
+            'feedforward_channels': int(128 * 4 * 2 / 3)
+        }
+        model = ViTEVA02(**cfg)
+        self.assertEqual(model.embed_dims, 128)
+        self.assertEqual(model.num_layers, 6)
+        for layer in model.layers:
+            self.assertEqual(layer.attn.num_heads, 16)
+
+        # Test out_indices
+        cfg = deepcopy(self.cfg)
+        cfg['out_indices'] = {1: 1}
+        with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
+            ViTEVA02(**cfg)
+        cfg['out_indices'] = [0, 13]
+        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
+            ViTEVA02(**cfg)
+
+        # Test model structure
+        cfg = deepcopy(self.cfg)
+        model = ViTEVA02(**cfg)
+        self.assertEqual(len(model.layers), 12)
+        self.assertEqual(model.cls_token.shape, (1, 1, 192))
+        self.assertEqual(model.pos_embed.shape, (1, 577, 192))
+        dpr_inc = 0.1 / (12 - 1)
+        dpr = 0
+        for layer in model.layers:
+            self.assertEqual(layer.attn.embed_dims, 192)
+            self.assertEqual(layer.attn.num_heads, 3)
+            self.assertAlmostEqual(layer.drop_path.drop_prob, dpr)
+            self.assertAlmostEqual(layer.mlp.dropout_layer.p, 0.1)
+            self.assertAlmostEqual(layer.attn.attn_drop.p, 0.2)
+            self.assertAlmostEqual(layer.attn.proj_drop.p, 0.3)
+            dpr += dpr_inc
+
+        # Test model structure: final_norm
+        cfg = deepcopy(self.cfg)
+        cfg['final_norm'] = True
+        model = ViTEVA02(**cfg)
+        self.assertNotEqual(model.norm1.__class__, torch.nn.Identity)
+
+    def test_forward(self):
+        imgs = torch.randn(1, 3, 336, 336)
+
+        # test with_cls_token=False
+        cfg = deepcopy(self.cfg)
+        cfg['with_cls_token'] = False
+        cfg['out_type'] = 'cls_token'
+        with self.assertRaisesRegex(ValueError, 'must be True'):
+            ViTEVA02(**cfg)
+
+        cfg = deepcopy(self.cfg)
+        cfg['with_cls_token'] = False
+        cfg['out_type'] = 'raw'
+        model = ViTEVA02(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        patch_token = outs[-1]
+        self.assertEqual(patch_token.shape, (1, 24 * 24, 192))
+
+        cfg = deepcopy(self.cfg)
+        cfg['with_cls_token'] = False
+        cfg['out_type'] = 'featmap'
+        model = ViTEVA02(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        patch_token = outs[-1]
+        self.assertEqual(patch_token.shape, (1, 192, 24, 24))
+
+        cfg = deepcopy(self.cfg)
+        cfg['with_cls_token'] = False
+        cfg['out_type'] = 'avg_featmap'
+        model = ViTEVA02(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        patch_token = outs[-1]
+        self.assertEqual(patch_token.shape, (1, 192))
+
+        # test with output cls_token
+        cfg = deepcopy(self.cfg)
+        model = ViTEVA02(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        cls_token = outs[-1]
+        self.assertEqual(cls_token.shape, (1, 192))
+
+        # Test forward with multi out indices
+        cfg = deepcopy(self.cfg)
+        cfg['out_indices'] = [-3, -2, -1]
+        model = ViTEVA02(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 3)
+        for out in outs:
+            self.assertEqual(out.shape, (1, 192))
diff --git a/tests/test_models/test_utils/test_position_encoding.py b/tests/test_models/test_utils/test_position_encoding.py
index 221a20df..7d80023c 100644
--- a/tests/test_models/test_utils/test_position_encoding.py
+++ b/tests/test_models/test_utils/test_position_encoding.py
@@ -1,10 +1,21 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 
-from mmpretrain.models.utils import ConditionalPositionEncoding
+from mmpretrain.models.utils import (ConditionalPositionEncoding,
+                                     RotaryEmbeddingFast)
 
 
 def test_conditional_position_encoding_module():
     CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
     outs = CPE(torch.randn(1, 3136, 32), (56, 56))
     assert outs.shape == torch.Size([1, 784, 32])
+
+
+def test_rotary_embedding_fast_module():
+    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=24)
+    outs = RoPE(torch.randn(1, 2, 24 * 24, 64), (24, 24))
+    assert outs.shape == torch.Size([1, 2, 24 * 24, 64])
+
+    RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=(14, 20))
+    outs = RoPE(torch.randn(1, 2, 14 * 20, 64), (14, 20))
+    assert outs.shape == torch.Size([1, 2, 14 * 20, 64])
diff --git a/tools/model_converters/eva02_to_mmpretrain.py b/tools/model_converters/eva02_to_mmpretrain.py
new file mode 100644
index 00000000..e5a8682f
--- /dev/null
+++ b/tools/model_converters/eva02_to_mmpretrain.py
@@ -0,0 +1,153 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def convert_eva02(ckpt):
+
+    new_ckpt = OrderedDict()
+    qkv_proj = {}
+    qkv_bias = {}
+    w12_weight = {}
+    w12_bias = {}
+
+    banned = {
+        'mask_token',
+        'lm_head.weight',
+        'lm_head.bias',
+        'norm.weight',
+        'norm.bias',
+    }
+
+    for k, v in list(ckpt.items()):
+
+        if k in banned:
+            continue
+
+        if k.startswith('head'):
+            new_k = k.replace('head.', 'head.fc.')
+            new_ckpt[new_k] = v
+        else:
+            if k.startswith('patch_embed'):
+                new_k = k.replace('proj.', 'projection.')
+
+            elif k.startswith('fc_norm') or k.startswith('norm'):
+                new_k = k.replace('norm.', 'ln2.')
+                new_k = k.replace('fc_norm.', 'ln2.')
+
+            elif k.startswith('blocks'):
+                new_k = k.replace('blocks.', 'layers.')
+
+                if 'mlp' in new_k:
+                    if 'w1.' in new_k or 'w2.' in new_k:
+                        # For base and large version, mlp is implemented with
+                        # 2 linears, where w1 and w2 are required to integrate
+                        # into w12.
+                        s = new_k.split('.')  # e.g. layers.0.mlp.w1.weight
+                        idx = s[1]
+                        if 'weight' in new_k:
+                            # w1.weight or w2.weight
+                            if idx not in w12_weight:
+                                w12_weight[idx] = {}
+                            w12_weight[idx][s[-2]] = v
+                        else:
+                            # w1.bias or w2.bias
+                            if idx not in w12_bias:
+                                w12_bias[idx] = {}
+                            w12_bias[idx][s[-2]] = v
+                        continue
+
+                    if 'ffn_ln' in new_k:
+                        new_k = new_k.replace('ffn_ln.', 'norm.')
+
+                elif 'attn' in new_k:
+                    if 'q_proj.weight' in new_k or \
+                            'k_proj.weight' in new_k or \
+                            'v_proj.weight' in new_k:
+                        # For base and large version, qkv projection is
+                        # implemented with three linear layers,
+                        s = new_k.split('.')
+                        idx = s[1]
+                        if idx not in qkv_proj:
+                            qkv_proj[idx] = {}
+                        qkv_proj[idx][s[-2]] = v
+                        continue
+
+                    if 'q_bias' in new_k or 'v_bias' in new_k:
+                        # k_bias is 0
+                        s = new_k.split('.')
+                        idx = s[1]
+                        if idx not in qkv_bias:
+                            qkv_bias[idx] = {}
+                        qkv_bias[idx][s[-1]] = v
+                        continue
+
+            else:
+                new_k = k
+
+            new_k = 'backbone.' + new_k
+            new_ckpt[new_k] = v
+
+    for idx in qkv_proj:
+        q_proj = qkv_proj[idx]['q_proj']
+        k_proj = qkv_proj[idx]['k_proj']
+        v_proj = qkv_proj[idx]['v_proj']
+        weight = torch.cat((q_proj, k_proj, v_proj))
+        new_k = f'backbone.layers.{idx}.attn.qkv.weight'
+        new_ckpt[new_k] = weight
+
+    for idx in qkv_bias:
+        q_bias = qkv_bias[idx]['q_bias']
+        k_bias = torch.zeros_like(q_bias)
+        v_bias = qkv_bias[idx]['v_bias']
+        weight = torch.cat((q_bias, k_bias, v_bias))
+        new_k = f'backbone.layers.{idx}.attn.qkv.bias'
+        new_ckpt[new_k] = weight
+
+    for idx in w12_weight:
+        w1 = w12_weight[idx]['w1']
+        w2 = w12_weight[idx]['w2']
+        weight = torch.cat((w1, w2))
+        new_k = f'backbone.layers.{idx}.mlp.w12.weight'
+        new_ckpt[new_k] = weight
+
+    for idx in w12_bias:
+        w1 = w12_bias[idx]['w1']
+        w2 = w12_bias[idx]['w2']
+        weight = torch.cat((w1, w2))
+        new_k = f'backbone.layers.{idx}.mlp.w12.bias'
+        new_ckpt[new_k] = weight
+
+    return new_ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in pretrained eva02 '
+        'models to mmpretrain style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+    if 'module' in checkpoint:
+        state_dict = checkpoint['module']
+    else:
+        state_dict = checkpoint
+
+    weight = convert_eva02(state_dict)
+    mmengine.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
+
+    print('Done!!')
+
+
+if __name__ == '__main__':
+    main()