[Feature] add eva02 backbone (#1450)

* [CI] Add test mim CI. (#879)

* feat: add eva02 backbone

* update

* update ci

* rebase

* update readme and configs

* refactore eva02

* update readme and metafile

* rename eva02

* fix uts

* rename configs
---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
zzc98 2023-05-06 19:28:31 +08:00 committed by GitHub
parent 7f4eccbecf
commit 034919d032
20 changed files with 1317 additions and 2 deletions

View File

@ -0,0 +1,62 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=448,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=448,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=448),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=16,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

View File

@ -0,0 +1,109 @@
# EVA-02
> [EVA-02: A Visual Representation for Neon Genesis](https://arxiv.org/abs/2303.11331)
<!-- [ALGORITHM] -->
## Abstract
We launch EVA-02, a next-generation Transformer-based visual representation pre-trained to reconstruct strong and robust language-aligned vision features via masked image modeling. With an updated plain Transformer architecture as well as extensive pre-training from an open & accessible giant CLIP vision encoder, EVA-02 demonstrates superior performance compared to prior state-of-the-art approaches across various representative vision tasks, while utilizing significantly fewer parameters and compute budgets. Notably, using exclusively publicly accessible training data, EVA-02 with only 304M parameters achieves a phenomenal 90.0 fine-tuning top-1 accuracy on ImageNet-1K val set. Additionally, our EVA-02-CLIP can reach up to 80.4 zero-shot top-1 on ImageNet-1K, outperforming the previous largest & best open-sourced CLIP with only ~1/6 parameters and ~1/6 image-text training data. We offer four EVA-02 variants in various model sizes, ranging from 6M to 304M parameters, all with impressive performance. To facilitate open access and open research, we release the complete suite of EVA-02 to the community.
<div align=center>
<img src="https://user-images.githubusercontent.com/40905160/229037980-b83dceb5-41d6-406c-a20b-63b83c80136d.png" width="70%" alt="TrV builds upon the original plain ViT architecture and includes several enhancements: SwinGLU FFN, sub-LN, 2D RoPE, and JAX weight initialization. To keep the parameter & FLOPs consistent with the baseline, the FFN hidden dim of SwiGLU is 2/3× of the typical MLP counterpart."/>
</div>
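The 2/3× sizing mentioned in the caption matches the `feedforward_channels` values hard-coded in the `ViTEVA02` architecture presets added by this PR; as a quick arithmetic check (embedding dims taken from the arch table in `vit_eva02.py`):
```python
# SwiGLU FFN hidden dim = 2/3 of the usual 4x MLP expansion,
# i.e. int(embed_dims * 4 * 2 / 3) as in ViTEVA02.arch_zoo.
for name, embed_dims in {'tiny': 192, 'small': 384, 'base': 768, 'large': 1024}.items():
    print(name, int(embed_dims * 4 * 2 / 3))
# tiny 512, small 1024, base 2048, large 2730
```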
## How to use it?
<!-- [TABS-BEGIN] -->
**Predict image**
```python
from mmpretrain import inference_model
predict = inference_model('vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```
**Use the model**
```python
import torch
from mmpretrain import get_model
model = get_model('vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px', pretrained=True)
inputs = torch.rand(1, 3, 336, 336)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/eva02/eva02-tiny-p14_in1k.py
```
Test:
```shell
python tools/test.py configs/eva02/eva02-tiny-p14_in1k.py /path/to/eva02-tiny-p14_in1k.pth
```
<!-- [TABS-END] -->
## Models and results
### Pretrained models
| Model | Params (M) | Flops (G) | Config | Download |
| :-------------------------------- | :--------: | :-------: | :-----------------------------------: | :-----------------------------------------------------------------------------------------------------------: |
| `vit-tiny-p14_eva02-pre_in21k`\* | 5.50 | 1.70 | [config](eva02-tiny-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_pre_in21k_20230505-d703e7b1.pth) |
| `vit-small-p14_eva02-pre_in21k`\* | 21.62 | 6.14 | [config](eva02-small-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_pre_in21k_20230505-3175f463.pth) |
| `vit-base-p14_eva02-pre_in21k`\* | 85.77 | 23.22 | [config](eva02-base-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_pre_in21k_20230505-2f2d4d3c.pth) |
| `vit-large-p14_eva02-pre_in21k`\* | 303.29 | 81.15 | [config](eva02-large-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_in21k_20230505-9072de5d.pth) |
| `vit-large-p14_eva02-pre_m38m`\* | 303.29 | 81.15 | [config](eva02-large-p14_headless.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_m38m_20230505-b8a1a261.pth) |
- The input size / patch size of MIM pre-trained EVA-02 is `224x224` / `14x14`.
*Models with * are converted from the [official repo](https://github.com/baaivision/EVA).*
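The headless checkpoints above carry no classification head, but they can still be used for feature extraction. A minimal sketch, assuming the tiny model name from the table and the 224x224 MIM pre-training input size:
```python
import torch
from mmpretrain import get_model

model = get_model('vit-tiny-p14_eva02-pre_in21k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
# The headless configs use out_type='avg_featmap', so this returns a tuple
# holding a single pooled feature of shape (1, 192) for the tiny model.
feats = model.extract_feat(inputs)
print(type(feats), feats[0].shape)
```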
### Image Classification on ImageNet-1k
#### (*w/o* IN-21K intermediate fine-tuning)
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :---------------------------------------------------- | :----------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------: | :-------------------------------------------------------: |
| `vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px`\* | EVA02 ImageNet-21k | 5.76 | 4.68 | 80.69 | 95.54 | [config](./eva02-tiny-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_in21k-pre_3rdparty_in1k-336px_20230505-a4e8708a.pth) |
| `vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px`\* | EVA02 ImageNet-21k | 22.13 | 15.48 | 85.78 | 97.60 | [config](./eva02-small-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_in21k-pre_3rdparty_in1k-336px_20230505-9c5b0e85.pth) |
| `vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k | 87.13 | 107.11 | 88.29 | 98.53 | [config](./eva02-base-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_3rdparty_in1k-448px_20230505-8ad211c5.pth) |
*Models with * are converted from the [official repo](https://github.com/baaivision/EVA/tree/master/EVA-02). The config files of these models are only for inference. We haven't reproduced the training results.*
#### (*w* IN-21K intermediate fine-tuning)
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :---------------------------------------------------- | :----------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------: | :-------------------------------------------------------: |
| `vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k | 87.13 | 107.11 | 88.47 | 98.62 | [config](./eva02-base-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-5cd4d87f.pth) |
| `vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 ImageNet-21k | 305.08 | 362.33 | 89.65 | 98.95 | [config](./eva02-large-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-926d1599.pth) |
| `vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px`\* | EVA02 Merged-38M | 305.10 | 362.33 | 89.83 | 99.00 | [config](./eva02-large-p14_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_m38m-pre_in21k-medft_3rdparty_in1k-448px_20230505-150dc5ed.pth) |
*Models with * are converted from the [official repo](https://github.com/baaivision/EVA/tree/master/EVA-02). The config files of these models are only for inference. We haven't reproduced the training results.*
## Citation
```bibtex
@article{EVA-02,
title={EVA-02: A Visual Representation for Neon Genesis},
author={Yuxin Fang and Quan Sun and Xinggang Wang and Tiejun Huang and Xinlong Wang and Yue Cao},
journal={arXiv preprint arXiv:2303.11331},
year={2023}
}
```

View File

@ -0,0 +1,21 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='b',
img_size=224,
patch_size=14,
sub_ln=True,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=None,
)
data_preprocessor = dict(
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)

View File

@ -0,0 +1,32 @@
_base_ = [
'../_base_/datasets/imagenet_bs16_eva_448.py',
'../_base_/schedules/imagenet_bs2048_AdamW.py',
'../_base_/default_runtime.py'
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='b',
img_size=448,
patch_size=14,
sub_ln=True,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))

View File

@ -0,0 +1,21 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='l',
img_size=224,
patch_size=14,
sub_ln=True,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=None,
)
data_preprocessor = dict(
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)

View File

@ -0,0 +1,32 @@
_base_ = [
'../_base_/datasets/imagenet_bs16_eva_448.py',
'../_base_/schedules/imagenet_bs2048_AdamW.py',
'../_base_/default_runtime.py'
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='l',
img_size=448,
patch_size=14,
sub_ln=True,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1024,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))

View File

@ -0,0 +1,20 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='s',
img_size=224,
patch_size=14,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=None,
)
data_preprocessor = dict(
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)

View File

@ -0,0 +1,31 @@
_base_ = [
'../_base_/datasets/imagenet_bs16_eva_336.py',
'../_base_/schedules/imagenet_bs2048_AdamW.py',
'../_base_/default_runtime.py'
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='s',
img_size=336,
patch_size=14,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))

View File

@ -0,0 +1,20 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='t',
img_size=224,
patch_size=14,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=None,
)
data_preprocessor = dict(
# RGB format normalization parameters
mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
# convert image from BGR to RGB
to_rgb=True,
)

View File

@ -0,0 +1,31 @@
_base_ = [
'../_base_/datasets/imagenet_bs16_eva_336.py',
'../_base_/schedules/imagenet_bs2048_AdamW.py',
'../_base_/default_runtime.py'
]
model = dict(
type='ImageClassifier',
backbone=dict(
type='ViTEVA02',
arch='t',
img_size=336,
patch_size=14,
final_norm=False,
out_type='avg_featmap'),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=192,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))

View File

@ -0,0 +1,199 @@
Collections:
- Name: EVA02
Metadata:
Architecture:
- Rotary Position Embedding
- Sub Layer Normalization
- SwiGLU
Paper:
Title: 'EVA-02: A Visual Representation for Neon Genesis'
URL: https://arxiv.org/abs/2303.11331
README: configs/eva02/README.md
Models:
- Name: vit-tiny-p14_eva02-pre_in21k
Metadata:
FLOPs: 1703439360
Parameters: 5504064
Training Data:
- ImageNet-21k
In Collection: EVA02
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_pre_in21k_20230505-d703e7b1.pth
Config: configs/eva02/eva02-tiny-p14_headless.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_Ti_pt_in21k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
Downstream:
- vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px
- Name: vit-tiny-p14_eva02-in21k-pre_3rdparty_in1k-336px
Metadata:
FLOPs: 4675416000
Parameters: 5758888
Training Data:
- ImageNet-21k
- ImageNet-1k
In Collection: EVA02
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 80.69
Top 5 Accuracy: 95.54
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-tiny-p14_in21k-pre_3rdparty_in1k-336px_20230505-a4e8708a.pth
Config: configs/eva02/eva02-tiny-p14_in1k.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
- Name: vit-small-p14_eva02-pre_in21k
Metadata:
FLOPs: 6135404544
Parameters: 21624960
Training Data:
- ImageNet-21k
In Collection: EVA02
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_pre_in21k_20230505-3175f463.pth
Config: configs/eva02/eva02-small-p14_headless.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_S_pt_in21k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
Downstream:
- vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px
- Name: vit-small-p14_eva02-in21k-pre_3rdparty_in1k-336px
Metadata:
FLOPs: 15476744064
Parameters: 22133608
Training Data:
- ImageNet-21k
- ImageNet-1k
In Collection: EVA02
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 85.78
Top 5 Accuracy: 97.60
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-small-p14_in21k-pre_3rdparty_in1k-336px_20230505-9c5b0e85.pth
Config: configs/eva02/eva02-small-p14_in1k.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
- Name: vit-base-p14_eva02-pre_in21k
Metadata:
FLOPs: 23216492544
Parameters: 85766400
Training Data:
- ImageNet-21k
In Collection: EVA02
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_pre_in21k_20230505-2f2d4d3c.pth
Config: configs/eva02/eva02-base-p14_headless.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_B_pt_in21k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
Downstream:
- vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px
- vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
- Name: vit-base-p14_eva02-in21k-pre_3rdparty_in1k-448px
Metadata:
FLOPs: 107105984256
Parameters: 87126760
Training Data:
- ImageNet-21k
- ImageNet-1k
In Collection: EVA02
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 88.29
Top 5 Accuracy: 98.53
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_3rdparty_in1k-448px_20230505-8ad211c5.pth
Config: configs/eva02/eva02-base-p14_in1k.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
- Name: vit-base-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
Metadata:
FLOPs: 107105984256
Parameters: 87126760
Training Data:
- ImageNet-21k
- ImageNet-1k
In Collection: EVA02
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 88.47
Top 5 Accuracy: 98.62
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-base-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-5cd4d87f.pth
Config: configs/eva02/eva02-base-p14_in1k.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
- Name: vit-large-p14_eva02-pre_in21k
Metadata:
FLOPs: 81146703792
Parameters: 303291328
Training Data:
- ImageNet-21k
In Collection: EVA02
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_in21k_20230505-9072de5d.pth
Config: configs/eva02/eva02-large-p14_headless.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_L_pt_in21k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
Downstream:
- vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
- Name: vit-large-p14_eva02-in21k-pre_in21k-medft_3rdparty_in1k-448px
Metadata:
FLOPs: 362333836208
Parameters: 305104808
Training Data:
- ImageNet-21k
- ImageNet-1k
In Collection: EVA02
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 89.65
Top 5 Accuracy: 98.95
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_in21k-pre_in21k-medft_3rdparty_in1k-448px_20230505-926d1599.pth
Config: configs/eva02/eva02-large-p14_in1k.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
- Name: vit-large-p14_eva02-pre_m38m
Metadata:
FLOPs: 81146703792
Parameters: 303291328
Training Data:
- Merged-38M
In Collection: EVA02
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_pre_m38m_20230505-b8a1a261.pth
Config: configs/eva02/eva02-large-p14_headless.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/pt/eva02_L_pt_m38m_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02
Downstream:
- vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px
- Name: vit-large-p14_eva02_m38m-pre_in21k-medft_3rdparty_in1k-448px
Metadata:
FLOPs: 362333836208
Parameters: 305104808
Training Data:
- Merged-38M
- ImageNet-21k
- ImageNet-1k
In Collection: EVA02
Results:
- Dataset: ImageNet-1k
Task: Image Classification
Metrics:
Top 1 Accuracy: 89.83
Top 5 Accuracy: 99.00
Weights: https://download.openmmlab.com/mmpretrain/v1.0/eva02/eva02-large-p14_m38m-pre_in21k-medft_3rdparty_in1k-448px_20230505-150dc5ed.pth
Config: configs/eva02/eva02-large-p14_in1k.py
Converted From:
Weights: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt
Code: https://github.com/baaivision/EVA/tree/master/EVA-02

View File

@ -189,6 +189,7 @@ Backbones
VisionTransformer
ViTSAM
XCiT
ViTEVA02
.. module:: mmpretrain.models.necks

View File

@ -52,6 +52,7 @@ from .van import VAN
from .vgg import VGG
from .vig import PyramidVig, Vig
from .vision_transformer import VisionTransformer
from .vit_eva02 import ViTEVA02
from .vit_sam import ViTSAM
from .xcit import XCiT
@ -118,4 +119,5 @@ __all__ = [
'PyramidVig',
'XCiT',
'ViTSAM',
'ViTEVA02',
]

View File

@ -0,0 +1,350 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, ModuleList
from mmpretrain.registry import MODELS
from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer,
resize_pos_embed)
from .vision_transformer import VisionTransformer
class AttentionWithRoPE(BaseModule):
"""Multi-head Attention Module with 2D sincos position embedding (RoPE).
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
qkv_bias (bool): If True, add a learnable bias to q and v. Note
that we follow the official implementation where ``k_bias``
is 0. Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): If True, add a learnable bias to the output projection.
Defaults to True.
rope (:obj:`torch.nn.Module`, optional): If a ``RotaryEmbeddingFast``
module is given, rotary position encoding is applied to the query
and key before the softmax. Defaults to None.
with_cls_token (bool): Whether a class token is concatenated with the
image tokens as the transformer input (the class token is excluded
from RoPE). Defaults to True.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
qkv_bias=True,
qk_scale=None,
proj_bias=True,
rope=None,
with_cls_token=True,
init_cfg=None):
super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
self.head_dims = embed_dims // num_heads
self.scale = qk_scale or self.head_dims**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.with_cls_token = with_cls_token
self.rope = rope
def forward(self, x, patch_resolution):
B, N, _ = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0)
if self.rope:
if self.with_cls_token:
q_t = q[:, :, 1:, :]
ro_q_t = self.rope(q_t, patch_resolution)
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
k_t = k[:, :, 1:, :] if self.with_cls_token else k
ro_k_t = self.rope(k_t, patch_resolution)
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
else:
q = self.rope(q, patch_resolution)
k = self.rope(k, patch_resolution)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1).type_as(x)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class EVA02EndcoderLayer(BaseModule):
"""Implements one encoder EVA02EndcoderLayer in EVA02.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension of FFNs.
sub_ln (bool): Whether to add sub layer normalization (an extra norm
layer inside the SwiGLU FFN). Defaults to False.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
qkv_bias (bool): Enable bias for the qkv projection if True.
Defaults to False.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): enable bias for projection in the attention module
if True. Defaults to True.
rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object
in the attention module. Defaults to None.
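with_cls_token (bool): Whether a class token is prepended to the image
tokens (so that RoPE can skip it in the attention module).
Defaults to True.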
drop_rate (float): Dropout rate in the mlp module. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
sub_ln=False,
attn_drop=0.,
proj_drop=0.,
qkv_bias=False,
qk_scale=None,
proj_bias=True,
rope=None,
with_cls_token=True,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
init_cfg=None):
super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)
self.attn = AttentionWithRoPE(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
proj_bias=proj_bias,
rope=rope,
with_cls_token=with_cls_token)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate))
self.norm2 = build_norm_layer(norm_cfg, embed_dims)
if drop_rate > 0:
dropout_layer = dict(type='Dropout', drop_prob=drop_rate)
else:
dropout_layer = None
if sub_ln:
ffn_norm = norm_cfg
else:
ffn_norm = None
self.mlp = SwiGLUFFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
dropout_layer=dropout_layer,
norm_cfg=ffn_norm,
add_identity=False,
)
def forward(self, x, patch_resolution):
inputs = x
x = self.norm1(x)
x = self.attn(x, patch_resolution)
x = self.drop_path(x)
x = inputs + x
inputs = x
x = self.norm2(x)
x = self.mlp(x)
x = self.drop_path(x)
x = inputs + x
return x
@MODELS.register_module()
class ViTEVA02(VisionTransformer):
"""EVA02 Vision Transformer.
A PyTorch implementation of: `EVA-02: A Visual Representation for Neon Genesis
<https://arxiv.org/abs/2303.11331>`_
Args:
arch (str | dict): Vision Transformer architecture. If use string,
choose from 'tiny', 'small', 'base', 'large'. If use dict,
it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in feedforward modules.
Defaults to 'tiny'.
sub_ln (bool): Whether to add sub layer normalization in the SwiGLU
FFN. Defaults to False.
drop_rate (float): Probability of an element to be zeroed in the
mlp module. Defaults to 0.
attn_drop_rate (float): Probability of an element to be zeroed after
the softmax in the attention. Defaults to 0.
proj_drop_rate (float): Probability of an element to be zeroed after
projection in the attention. Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
with_cls_token (bool): Whether to concatenate a class token with the
image tokens as the transformer input. Defaults to True.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
**kwargs (dict, optional): Other args for ``VisionTransformer``.
"""
arch_zoo = {
**dict.fromkeys(
['t', 'ti', 'tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': int(192 * 4 * 2 / 3)
}),
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': int(384 * 4 * 2 / 3)
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': int(768 * 4 * 2 / 3)
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': int(1024 * 4 * 2 / 3)
})
}
num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
arch='tiny',
sub_ln=False,
drop_rate=0.,
attn_drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN'),
with_cls_token=True,
layer_cfgs=dict(),
**kwargs):
# set essential args for Vision Transformer
kwargs.update(
arch=arch,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
with_cls_token=with_cls_token)
super(ViTEVA02, self).__init__(**kwargs)
self.num_heads = self.arch_settings['num_heads']
# Set RoPE
head_dim = self.embed_dims // self.num_heads
self.rope = RotaryEmbeddingFast(
embed_dims=head_dim, patch_resolution=self.patch_resolution)
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=self.
arch_settings['feedforward_channels'],
sub_ln=sub_ln,
norm_cfg=norm_cfg,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_rate=drop_rate,
qkv_bias=qkv_bias,
rope=self.rope,
with_cls_token=with_cls_token,
drop_path_rate=dpr[i])
_layer_cfg.update(layer_cfgs[i])
self.layers.append(EVA02EndcoderLayer(**_layer_cfg))
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
x = self.pre_norm(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x, patch_resolution)
if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
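A short usage sketch of the backbone above, mirroring the settings of `eva02-base-p14_in1k.py` (the output shape follows `out_type='avg_featmap'` and the base embedding dim of 768):
```python
import torch
from mmpretrain.models.backbones import ViTEVA02

# Base preset at 448px with sub layer normalization, as in the config.
backbone = ViTEVA02(
    arch='b',
    img_size=448,
    patch_size=14,
    sub_ln=True,
    final_norm=False,
    out_type='avg_featmap')
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 448, 448))
print(feats[0].shape)  # expected: torch.Size([1, 768])
```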

View File

@ -18,7 +18,7 @@ from .layer_scale import LayerScale
from .make_divisible import make_divisible
from .norm import GRN, LayerNorm2d, build_norm_layer
from .position_encoding import (ConditionalPositionEncoding,
PositionEncodingFourier,
PositionEncodingFourier, RotaryEmbeddingFast,
build_2d_sincos_position_embedding)
from .res_layer_extra_norm import ResLayerExtraNorm
from .se_layer import SELayer
@ -72,4 +72,5 @@ __all__ = [
'ResLayerExtraNorm',
'SwiGLUFFN',
'SwiGLUFFNFused',
'RotaryEmbeddingFast',
]

View File

@ -8,6 +8,8 @@ import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.utils import digit_version
from ..utils import to_2tuple
# After PyTorch v1.10.0, calling torch.meshgrid without the indexing
# argument will raise an extra warning. For more details,
# refer to https://github.com/pytorch/pytorch/issues/50276
@ -170,3 +172,76 @@ def build_2d_sincos_position_embedding(
pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
return pos_emb
class RotaryEmbeddingFast(BaseModule):
"""Implements 2D rotary embedding (RoPE) for image tokens. Position
encoding is implemented with sin and cos functions,
.. math::
Pos_{cos} = \cos\left(\frac{t}{\theta^{\frac{2i}{d}}}\right) \\
Pos_{sin} = \sin\left(\frac{t}{\theta^{\frac{2i}{d}}}\right)
Args:
embed_dims (int): The feature dimension for each head.
patch_resolution (int | tuple): The resolution of the
image, in format (H, W).
theta (float): The hyperparameter for position coding.
Defaults to 10000.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
patch_resolution,
theta=10000.,
init_cfg=None):
super(RotaryEmbeddingFast, self).__init__(init_cfg=init_cfg)
self.half_dim = embed_dims // 2
self.patch_resolution = to_2tuple(patch_resolution)
self.theta = theta
freqs_cos, freqs_sin = self.compute_position_embedding()
self.register_buffer('freqs_cos', freqs_cos)
self.register_buffer('freqs_sin', freqs_sin)
def compute_position_embedding(self):
frequency = self.theta**(
torch.arange(0, self.half_dim, 2).float() / self.half_dim)
frequency = 1. / frequency
h, w = self.patch_resolution
th = torch.arange(h) / h * self.half_dim
tw = torch.arange(w) / w * self.half_dim
position_h = (th[:, None] @ frequency[None, :]).repeat(1, 2)
position_w = (tw[:, None] @ frequency[None, :]).repeat(1, 2)
height = position_h[:, None, :].expand(h, w, self.half_dim)
width = position_w[None, :, :].expand(h, w, self.half_dim)
position = torch.cat((height, width), dim=-1)
freqs_cos = position.cos().view(-1, position.shape[-1])
freqs_sin = position.sin().view(-1, position.shape[-1])
return freqs_cos, freqs_sin
def forward(self, x, patch_resolution):
# Check whether the patch resolution is the predefined size
patch_resolution = to_2tuple(patch_resolution)
if patch_resolution != self.patch_resolution:
self.patch_resolution = patch_resolution
freqs_cos, freqs_sin = self.compute_position_embedding()
self.register_buffer('freqs_cos', freqs_cos.to(x.device))
self.register_buffer('freqs_sin', freqs_sin.to(x.device))
batch, num_heads, num_patches, dim = x.shape
inputs = x
x = x.reshape(batch, num_heads, num_patches, -1, 2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
x = x.reshape(batch, num_heads, num_patches, dim)
return inputs * self.freqs_cos + x * self.freqs_sin

View File

@ -69,4 +69,5 @@ Import:
- configs/riformer/metafile.yml
- configs/sam/metafile.yml
- configs/glip/metafile.yml
- configs/eva02/metafile.yml
- configs/dinov2/metafile.yml

View File

@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmpretrain.models.backbones import ViTEVA02
class TestEVA02(TestCase):
def setUp(self):
self.cfg = dict(
arch='t',
img_size=336,
patch_size=14,
drop_path_rate=0.1,
drop_rate=0.1,
attn_drop_rate=0.2,
proj_drop_rate=0.3,
)
def test_structure(self):
# Test invalid default arch
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'unknown'
ViTEVA02(**cfg)
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': int(24 * 4 * 2 / 3)
}
ViTEVA02(**cfg)
# Test custom arch
cfg = deepcopy(self.cfg)
cfg['arch'] = {
'embed_dims': 128,
'num_layers': 6,
'num_heads': 16,
'feedforward_channels': int(128 * 4 * 2 / 3)
}
model = ViTEVA02(**cfg)
self.assertEqual(model.embed_dims, 128)
self.assertEqual(model.num_layers, 6)
for layer in model.layers:
self.assertEqual(layer.attn.num_heads, 16)
# Test out_indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = {1: 1}
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
ViTEVA02(**cfg)
cfg['out_indices'] = [0, 13]
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
ViTEVA02(**cfg)
# Test model structure
cfg = deepcopy(self.cfg)
model = ViTEVA02(**cfg)
self.assertEqual(len(model.layers), 12)
self.assertEqual(model.cls_token.shape, (1, 1, 192))
self.assertEqual(model.pos_embed.shape, (1, 577, 192))
dpr_inc = 0.1 / (12 - 1)
dpr = 0
for layer in model.layers:
self.assertEqual(layer.attn.embed_dims, 192)
self.assertEqual(layer.attn.num_heads, 3)
self.assertAlmostEqual(layer.drop_path.drop_prob, dpr)
self.assertAlmostEqual(layer.mlp.dropout_layer.p, 0.1)
self.assertAlmostEqual(layer.attn.attn_drop.p, 0.2)
self.assertAlmostEqual(layer.attn.proj_drop.p, 0.3)
dpr += dpr_inc
# Test model structure: final_norm
cfg = deepcopy(self.cfg)
cfg['final_norm'] = True
model = ViTEVA02(**cfg)
self.assertNotEqual(model.norm1.__class__, torch.nn.Identity)
def test_forward(self):
imgs = torch.randn(1, 3, 336, 336)
# test with_cls_token=False
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['out_type'] = 'cls_token'
with self.assertRaisesRegex(ValueError, 'must be True'):
ViTEVA02(**cfg)
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['out_type'] = 'raw'
model = ViTEVA02(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 24 * 24, 192))
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['out_type'] = 'featmap'
model = ViTEVA02(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 192, 24, 24))
cfg = deepcopy(self.cfg)
cfg['with_cls_token'] = False
cfg['out_type'] = 'avg_featmap'
model = ViTEVA02(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
patch_token = outs[-1]
self.assertEqual(patch_token.shape, (1, 192))
# test with output cls_token
cfg = deepcopy(self.cfg)
model = ViTEVA02(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
cls_token = outs[-1]
self.assertEqual(cls_token.shape, (1, 192))
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
model = ViTEVA02(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)
for out in outs:
self.assertEqual(out.shape, (1, 192))

View File

@ -1,10 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpretrain.models.utils import ConditionalPositionEncoding
from mmpretrain.models.utils import (ConditionalPositionEncoding,
RotaryEmbeddingFast)
def test_conditional_position_encoding_module():
CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
outs = CPE(torch.randn(1, 3136, 32), (56, 56))
assert outs.shape == torch.Size([1, 784, 32])
def test_rotary_embedding_fast_module():
RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=24)
outs = RoPE(torch.randn(1, 2, 24 * 24, 64), (24, 24))
assert outs.shape == torch.Size([1, 2, 24 * 24, 64])
RoPE = RotaryEmbeddingFast(embed_dims=64, patch_resolution=(14, 20))
outs = RoPE(torch.randn(1, 2, 14 * 20, 64), (14, 20))
assert outs.shape == torch.Size([1, 2, 14 * 20, 64])

View File

@ -0,0 +1,153 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict
import mmengine
import torch
from mmengine.runner import CheckpointLoader
def convert_eva02(ckpt):
new_ckpt = OrderedDict()
qkv_proj = {}
qkv_bias = {}
w12_weight = {}
w12_bias = {}
banned = {
'mask_token',
'lm_head.weight',
'lm_head.bias',
'norm.weight',
'norm.bias',
}
for k, v in list(ckpt.items()):
if k in banned:
continue
if k.startswith('head'):
new_k = k.replace('head.', 'head.fc.')
new_ckpt[new_k] = v
else:
if k.startswith('patch_embed'):
new_k = k.replace('proj.', 'projection.')
elif k.startswith('fc_norm') or k.startswith('norm'):
# plain 'norm.weight' / 'norm.bias' keys are filtered out via
# `banned` above, so only 'fc_norm.*' keys reach this branch
new_k = k.replace('fc_norm.', 'ln2.')
elif k.startswith('blocks'):
new_k = k.replace('blocks.', 'layers.')
if 'mlp' in new_k:
if 'w1.' in new_k or 'w2.' in new_k:
# For base and large version, mlp is implemented with
# 2 linears, where w1 and w2 are required to integrate
# into w12.
s = new_k.split('.') # e.g. layers.0.mlp.w1.weight
idx = s[1]
if 'weight' in new_k:
# w1.weight or w2.weight
if idx not in w12_weight:
w12_weight[idx] = {}
w12_weight[idx][s[-2]] = v
else:
# w1.bias or w2.bias
if idx not in w12_bias:
w12_bias[idx] = {}
w12_bias[idx][s[-2]] = v
continue
if 'ffn_ln' in new_k:
new_k = new_k.replace('ffn_ln.', 'norm.')
elif 'attn' in new_k:
if 'q_proj.weight' in new_k or \
'k_proj.weight' in new_k or \
'v_proj.weight' in new_k:
# For base and large version, qkv projection is
# implemented with three linear layers,
s = new_k.split('.')
idx = s[1]
if idx not in qkv_proj:
qkv_proj[idx] = {}
qkv_proj[idx][s[-2]] = v
continue
if 'q_bias' in new_k or 'v_bias' in new_k:
# k_bias is 0
s = new_k.split('.')
idx = s[1]
if idx not in qkv_bias:
qkv_bias[idx] = {}
qkv_bias[idx][s[-1]] = v
continue
else:
new_k = k
new_k = 'backbone.' + new_k
new_ckpt[new_k] = v
for idx in qkv_proj:
q_proj = qkv_proj[idx]['q_proj']
k_proj = qkv_proj[idx]['k_proj']
v_proj = qkv_proj[idx]['v_proj']
weight = torch.cat((q_proj, k_proj, v_proj))
new_k = f'backbone.layers.{idx}.attn.qkv.weight'
new_ckpt[new_k] = weight
for idx in qkv_bias:
q_bias = qkv_bias[idx]['q_bias']
k_bias = torch.zeros_like(q_bias)
v_bias = qkv_bias[idx]['v_bias']
weight = torch.cat((q_bias, k_bias, v_bias))
new_k = f'backbone.layers.{idx}.attn.qkv.bias'
new_ckpt[new_k] = weight
for idx in w12_weight:
w1 = w12_weight[idx]['w1']
w2 = w12_weight[idx]['w2']
weight = torch.cat((w1, w2))
new_k = f'backbone.layers.{idx}.mlp.w12.weight'
new_ckpt[new_k] = weight
for idx in w12_bias:
w1 = w12_bias[idx]['w1']
w2 = w12_bias[idx]['w2']
weight = torch.cat((w1, w2))
new_k = f'backbone.layers.{idx}.mlp.w12.bias'
new_ckpt[new_k] = weight
return new_ckpt
def main():
parser = argparse.ArgumentParser(
description='Convert keys in pretrained eva02 '
'models to mmpretrain style.')
parser.add_argument('src', help='src model path or url')
# The dst path must be a full path of the new checkpoint.
parser.add_argument('dst', help='save path')
args = parser.parse_args()
checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
if 'module' in checkpoint:
state_dict = checkpoint['module']
else:
state_dict = checkpoint
weight = convert_eva02(state_dict)
mmengine.mkdir_or_exist(osp.dirname(args.dst))
torch.save(weight, args.dst)
print('Done!!')
if __name__ == '__main__':
main()
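As a quick illustration of the key mapping performed by `convert_eva02` (run with the function above in scope; the toy state dict below uses arbitrary shapes and only exercises the renaming/merging logic):
```python
import torch

toy = {
    'patch_embed.proj.weight': torch.zeros(192, 3, 14, 14),
    'blocks.0.attn.q_proj.weight': torch.zeros(192, 192),
    'blocks.0.attn.k_proj.weight': torch.zeros(192, 192),
    'blocks.0.attn.v_proj.weight': torch.zeros(192, 192),
    'head.weight': torch.zeros(1000, 192),
}
print(sorted(convert_eva02(toy)))
# expected: ['backbone.layers.0.attn.qkv.weight',
#            'backbone.patch_embed.projection.weight',
#            'head.fc.weight']
```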