[Feature]: Add MFF (#1725)
* [Feature]: Add MFF
* [Feature]: Add mff linear prob
* [Feature]: Add ft
* [Fix]: Update docstring
* [Feature]: Update out_indices
* [Feature]: Add prefix to ft
* [Feature]: Add README
* [Feature]: Update readme
* [Feature]: Update README
* [Feature]: Add metafile
* [Feature]: Update README
* [Fix]: Fix lint
* [Feature]: Add UT
* [Feature]: Update paper link

pull/1760/head

parent 2fb52eefdc
commit fa53174fd9
@@ -253,6 +253,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
 <li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
 <li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
 <li><a href="configs/spark">SparK (ICLR'2023)</a></li>
+<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
 </ul>
 </td>
 <td>

@@ -249,6 +249,7 @@ mim install -e ".[multimodal]"
 <li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
 <li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
 <li><a href="configs/spark">SparK (ICLR'2023)</a></li>
+<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
 </ul>
 </td>
 <td>

@@ -0,0 +1,60 @@
# MFF

> [Improving Pixel-based MIM by Reducing Wasted Modeling Capability](https://arxiv.org/abs/2308.00261)

<!-- [ALGORITHM] -->

## Abstract

There has been significant progress in Masked Image Modeling (MIM). Existing MIM methods can be broadly categorized into two groups based on the reconstruction target: pixel-based and tokenizer-based approaches. The former offers a simpler pipeline and lower computational cost, but it is known to be biased toward high-frequency details. In this paper, we provide a set of empirical studies to confirm this limitation of pixel-based MIM and propose a new method that explicitly utilizes low-level features from shallow layers to aid pixel reconstruction. By incorporating this design into our base method, MAE, we reduce the wasted modeling capability of pixel-based MIM, improving its convergence and achieving non-trivial improvements across various downstream tasks. To the best of our knowledge, we are the first to systematically investigate multi-level feature fusion for isotropic architectures like the standard Vision Transformer (ViT). Notably, when applied to a smaller model (e.g., ViT-S), our method yields significant performance gains, such as 1.2% on fine-tuning, 2.8% on linear probing, and 2.6% on semantic segmentation.

<div align=center>
<img src="https://user-images.githubusercontent.com/30762564/257412932-5f36b11b-ee64-4ce7-b7d1-a31000302bd8.png" width="80%"/>
</div>

**Train/Test Command**

Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).

Train:

```shell
python tools/train.py configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
```

Test:

```shell
python tools/test.py configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py None
```

<!-- [TABS-END] -->
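
For a quick sanity check in Python, the pre-trained model can also be loaded through the API. A minimal sketch, assuming the checkpoint is registered under the config file stem (verify with `mmpretrain.list_models('mff')`):

```python
import torch
from mmpretrain import get_model

# Assumed registry name: mmpretrain conventionally registers checkpoints
# under the config file stem.
model = get_model('mff_vit-base-p16_8xb512-amp-coslr-300e_in1k', pretrained=True)
# mask=None bypasses the masking branch and extracts plain features.
feats = model.backbone(torch.rand(1, 3, 224, 224), mask=None)
```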

## Models and results

### Pretrained models

| Model                                         | Params (M) | Flops (G) |                          Config                          |                                      Download                                      |
| :-------------------------------------------- | :--------: | :-------: | :------------------------------------------------------: | :------------------------------------------------------------------------------: |
| `mff_vit-base-p16_8xb512-amp-coslr-300e_in1k` | - | - | [config](mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.json) |
| `mff_vit-base-p16_8xb512-amp-coslr-800e_in1k` | - | - | [config](mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.json) |

### Image Classification on ImageNet-1k

| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Config | Download |
| :---------------------------------------- | :------------------------------------------: | :--------: | :-------: | :-------: | :----------------------------------------: | :-------------------------------------------: |
| `vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k` | [MFF 300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) | 86.57 | 17.58 | 83.00 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.json) |
| `vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) | 86.57 | 17.58 | 83.70 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.json) |
| `vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) | 86.57 | 17.58 | 64.20 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k.json) |
| `vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) | 86.57 | 17.58 | 68.30 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.json) |

## Citation

```bibtex
@article{MFF,
  title={Improving Pixel-based MIM by Reducing Wasted Modeling Capability},
  author={Yuan Liu and Songyang Zhang and Jiacheng Chen and Zhaohui Yu and Kai Chen and Dahua Lin},
  journal={arXiv preprint arXiv:2308.00261},
  year={2023}
}
```

@@ -0,0 +1,114 @@
_base_ = [
    '../../_base_/datasets/imagenet_bs64_swin_224.py',
    '../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../../_base_/default_runtime.py'
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=0.3333333333333333,
        fill_color=[103.53, 116.28, 123.675],
        fill_std=[57.375, 57.12, 58.395]),
    dict(type='PackInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackInputs')
]

train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='base',
        img_size=224,
        patch_size=16,
        drop_path_rate=0.1,
        out_type='avg_featmap',
        final_norm=False,
        init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
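    # The empty ``checkpoint`` above is filled with the MFF pre-trained
    # weights at launch time; one assumed invocation:
    # python tools/train.py configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py \
    #     --cfg-options model.backbone.init_cfg.checkpoint=<mff-checkpoint>.pth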
    neck=None,
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

# optimizer wrapper
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW', lr=2e-3, weight_decay=0.05, betas=(0.9, 0.999)),
    constructor='LearningRateDecayOptimWrapperConstructor',
    paramwise_cfg=dict(
        layer_decay_rate=0.65,
        custom_keys={
            '.ln': dict(decay_mult=0.0),
            '.bias': dict(decay_mult=0.0),
            '.cls_token': dict(decay_mult=0.0),
            '.pos_embed': dict(decay_mult=0.0)
        }))

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=95,
        by_epoch=True,
        begin=5,
        end=100,
        eta_min=1e-6,
        convert_to_iter_based=True)
]

# runtime settings
default_hooks = dict(
    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))

train_cfg = dict(by_epoch=True, max_epochs=100)

randomness = dict(seed=0, diff_rank_seed=True)

@@ -0,0 +1,74 @@
_base_ = [
    '../../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../../_base_/default_runtime.py'
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ToPIL', to_rgb=True),
    dict(type='MAERandomResizedCrop', size=224, interpolation=3),
    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
    dict(type='ToNumpy', to_bgr=True),
    dict(type='PackInputs'),
]

# dataset settings
train_dataloader = dict(
    batch_size=2048, drop_last=True, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(drop_last=False)
test_dataloader = dict(drop_last=False)

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='base',
        img_size=224,
        patch_size=16,
        frozen_stages=12,
        out_type='cls_token',
        final_norm=True,
        init_cfg=dict(type='Pretrained', prefix='backbone.')),
    neck=dict(type='ClsBatchNormNeck', input_features=768),
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss'),
        init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.01)]))

# optimizer
optim_wrapper = dict(
    _delete_=True,
    type='AmpOptimWrapper',
    optimizer=dict(type='LARS', lr=6.4, weight_decay=0.0, momentum=0.9))
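
# lr=6.4 is consistent with the linear-scaling rule used for MAE-style
# linear probing (an assumed rationale, not stated in this config):
# lr = base_lr * total_batch_size / 256 = 0.1 * (8 * 2048) / 256 = 6.4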

# learning rate scheduler
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=80,
        by_epoch=True,
        begin=10,
        end=90,
        eta_min=0.0,
        convert_to_iter_based=True)
]

# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=90)

default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
    logger=dict(type='LoggerHook', interval=10))

randomness = dict(seed=0, diff_rank_seed=True)

@@ -0,0 +1,103 @@
Collections:
  - Name: MFF
    Metadata:
      Training Data: ImageNet-1k
      Training Techniques:
        - AdamW
      Training Resources: 8x A100-80G GPUs
      Architecture:
        - ViT
    Paper:
      Title: Improving Pixel-based MIM by Reducing Wasted Modeling Capability
      URL: https://arxiv.org/pdf/2308.00261.pdf
    README: configs/mff/README.md

Models:
  - Name: mff_vit-base-p16_8xb512-amp-coslr-300e_in1k
    Metadata:
      Epochs: 300
      Batch Size: 4096
      FLOPs: 17581972224
      Parameters: 85882692
      Training Data: ImageNet-1k
    In Collection: MFF
    Results: null
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth
    Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
    Downstream:
      - vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
      - vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
  - Name: mff_vit-base-p16_8xb512-amp-coslr-800e_in1k
    Metadata:
      Epochs: 800
      Batch Size: 4096
      FLOPs: 17581972224
      Parameters: 85882692
      Training Data: ImageNet-1k
    In Collection: MFF
    Results: null
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth
    Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py
    Downstream:
      - vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
      - vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
  - Name: vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
    Metadata:
      Epochs: 100
      Batch Size: 1024
      FLOPs: 17581215744
      Parameters: 86566120
      Training Data: ImageNet-1k
    In Collection: MFF
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.0
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth
    Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
  - Name: vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
    Metadata:
      Epochs: 100
      Batch Size: 1024
      FLOPs: 17581215744
      Parameters: 86566120
      Training Data: ImageNet-1k
    In Collection: MFF
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.7
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth
    Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
  - Name: vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
    Metadata:
      Epochs: 90
      Batch Size: 16384
      FLOPs: 17581215744
      Parameters: 86566120
      Training Data: ImageNet-1k
    In Collection: MFF
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 64.2
    Weights: null
    Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
  - Name: vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
    Metadata:
      Epochs: 90
      Batch Size: 16384
      FLOPs: 17581215744
      Parameters: 86566120
      Training Data: ImageNet-1k
    In Collection: MFF
    Results:
      - Task: Image Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 68.3
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.pth
    Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py

@@ -0,0 +1,24 @@
_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'

randomness = dict(seed=2, diff_rank_seed=True)

# dataset config
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ToPIL', to_rgb=True),
    dict(type='torchvision/Resize', size=224),
    dict(
        type='torchvision/RandomCrop',
        size=224,
        padding=4,
        padding_mode='reflect'),
    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
    dict(type='ToNumpy', to_bgr=True),
    dict(type='PackInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# model config
model = dict(
    type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11]))
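
# out_indices selects the transformer layers to fuse: shallow layers 1, 3,
# 5, 7 and 9 plus the final layer 12 (zero-indexed above). MFFViT projects
# each shallow output and combines all of them with softmax-normalized
# learnable weights.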

@@ -0,0 +1,24 @@
_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-800e_in1k.py'

randomness = dict(seed=2, diff_rank_seed=True)

# dataset config
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ToPIL', to_rgb=True),
    dict(type='torchvision/Resize', size=224),
    dict(
        type='torchvision/RandomCrop',
        size=224,
        padding=4,
        padding_mode='reflect'),
    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
    dict(type='ToNumpy', to_bgr=True),
    dict(type='PackInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# model config
model = dict(
    type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11]))

@@ -41,8 +41,8 @@ if WITH_MULTIMODAL:
     from .flickr30k_caption import Flickr30kCaption
     from .flickr30k_retrieval import Flickr30kRetrieval
     from .gqa_dataset import GQA
-    from .infographic_vqa import InfographicVQA
     from .iconqa import IconQA
+    from .infographic_vqa import InfographicVQA
     from .nocaps import NoCaps
     from .ocr_vqa import OCRVQA
     from .refcoco import RefCOCO

@@ -12,8 +12,9 @@ from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
                          PILToNumpy, Transpose)
 from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
                          ColorJitter, EfficientNetCenterCrop,
-                         EfficientNetRandomCrop, Lighting, RandomCrop,
-                         RandomErasing, RandomResizedCrop,
+                         EfficientNetRandomCrop, Lighting,
+                         MAERandomResizedCrop, RandomCrop, RandomErasing,
+                         RandomResizedCrop,
                          RandomResizedCropAndInterpolationWithTwoPic,
                          RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
 from .utils import get_transform_idx, remove_transform

@@ -36,5 +37,5 @@ __all__ = [
     'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
     'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
     'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
-    'remove_transform'
+    'remove_transform', 'MAERandomResizedCrop'
 ]

@@ -11,9 +11,13 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 import mmcv
 import mmengine
 import numpy as np
 import torch
+import torchvision
+import torchvision.transforms.functional as F
 from mmcv.transforms import BaseTransform
 from mmcv.transforms.utils import cache_randomness
 from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.transforms import InterpolationMode

 from mmpretrain.registry import TRANSFORMS

@@ -1740,3 +1744,52 @@ class RandomTranslatePad(BaseTransform):
                results['gt_bboxes'][i] = box

        return results


@TRANSFORMS.register_module()
class MAERandomResizedCrop(transforms.RandomResizedCrop):
    """RandomResizedCrop for matching TF/TPU implementation: no for-loop is
    used.

    This may lead to results different from torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 # noqa: E501
    """

    @staticmethod
    def get_params(img: Image.Image, scale: tuple, ratio: tuple) -> Tuple:
        width, height = img.size
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1, )).item()
        j = torch.randint(0, width - w + 1, size=(1, )).item()

        return i, j, h, w

    def forward(self, results: dict) -> dict:
        """The forward function of MAERandomResizedCrop.

        Args:
            results (dict): The results dict, which contains the image and
                all the information related to it.

        Returns:
            dict: The results dict with the cropped image and all the
                information related to it.
        """
        img = results['img']
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
        results['img'] = img
        return results
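
# Unlike torchvision's RandomResizedCrop.get_params, which retries the crop
# up to ten times and falls back to a center crop, this variant samples a
# single crop and clamps it to the image bounds, matching the TF behavior.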

@@ -436,7 +436,7 @@ class VisionTransformer(BaseBackbone):
         for param in self.pre_norm.parameters():
             param.requires_grad = False
         # freeze cls_token
-        if self.cls_token:
+        if self.cls_token is not None:
             self.cls_token.requires_grad = False
         # freeze layers
         for i in range(1, self.frozen_stages + 1):

@@ -9,6 +9,7 @@ from .eva import EVA
 from .itpn import iTPN, iTPNHiViT
 from .mae import MAE, MAEHiViT, MAEViT
 from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT
+from .mff import MFF, MFFViT
 from .milan import MILAN, CLIPGenerator, MILANViT
 from .mixmim import MixMIM, MixMIMPretrainTransformer
 from .moco import MoCo

@@ -53,4 +54,6 @@ __all__ = [
     'BarlowTwins',
     'SwAV',
     'SparK',
+    'MFF',
+    'MFFViT',
 ]

@@ -0,0 +1,194 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F

from mmpretrain.models.selfsup.mae import MAE, MAEViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MFFViT(MAEViT):
    """Vision Transformer for MFF pre-training.

    This class inherits all its functionalities from ``MAEViT`` and adds
    multi-level feature fusion to it. For more details, you can refer to
    `Improving Pixel-based MIM by Reducing Wasted Modeling Capability`.

    Args:
        arch (str | dict): Vision Transformer architecture.
            Defaults to 'b'.
        img_size (int | tuple): Input image size.
        patch_size (int | tuple): The patch size.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            It only works without input mask. Defaults to ``"avg_featmap"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        mask_ratio (float): The ratio of total number of patches to be masked.
            Defaults to 0.75.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'b',
                 img_size: int = 224,
                 patch_size: int = 16,
                 out_indices: Union[Sequence, int] = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 out_type: str = 'raw',
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
                 mask_ratio: float = 0.75,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            mask_ratio=mask_ratio,
            init_cfg=init_cfg)
        proj_layers = [
            torch.nn.Linear(self.embed_dims, self.embed_dims)
            for _ in range(len(self.out_indices) - 1)
        ]
        self.proj_layers = torch.nn.ModuleList(proj_layers)
        self.proj_weights = torch.nn.Parameter(
            torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
        if len(self.out_indices) == 1:
            self.proj_weights.requires_grad = False

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate features for masked images.

        The function supports two kinds of forward behaviors. If the ``mask``
        is ``True``, the function will generate a mask to mask some patches
        randomly and get the hidden features of the visible patches, which
        means the function will be executed as masked image modeling
        pre-training; if the ``mask`` is ``None`` or ``False``, the forward
        function will call ``super().forward()``, which extracts features
        from images without a mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): To indicate whether the forward function
                generating ``mask`` or not.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            Hidden features, mask, the ids to restore the original image,
            and the fusion weights.

            - ``x`` (torch.Tensor): hidden features, which is of shape
              B x (L * mask_ratio) x C.
            - ``mask`` (torch.Tensor): mask used to mask image.
            - ``ids_restore`` (torch.Tensor): ids to restore original image.
            - ``proj_weights`` (torch.Tensor): softmax-normalized fusion
              weights, one per selected layer.
        """
        if mask is None or mask is False:
            return super().forward(x)

        else:
            B = x.shape[0]
            x = self.patch_embed(x)[0]
            # add pos embed w/o cls token
            x = x + self.pos_embed[:, 1:, :]

            # masking: length -> length * mask_ratio
            x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

            # append cls token
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
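
            # Multi-level feature fusion: collect the tokens of every layer
            # in ``out_indices`` (all but the last pass through a linear
            # projection), then mix them with softmax-normalized learnable
            # weights before the final norm.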
            res = []
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i in self.out_indices:
                    if i != self.out_indices[-1]:
                        proj_x = self.proj_layers[self.out_indices.index(i)](x)
                    else:
                        proj_x = x
                    res.append(proj_x)
            res = torch.stack(res)
            proj_weights = F.softmax(self.proj_weights, dim=0)
            res = res * proj_weights
            res = res.sum(dim=0)

            # Use final norm
            x = self.norm1(res)
            return (x, mask, ids_restore, proj_weights.view(-1))


@MODELS.register_module()
class MFF(MAE):
    """MFF.

    Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
    Capability`.
    """

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # ids_restore: the same as that in original repo, which is used
        # to recover the original order of tokens in decoder.
        latent, mask, ids_restore, weights = self.backbone(inputs)
        pred = self.neck(latent, ids_restore)
        loss = self.head.loss(pred, inputs, mask)
        weight_params = {
            f'weight_{i}': weights[i]
            for i in range(weights.size(0))
        }
        losses = dict(loss=loss)
        losses.update(weight_params)
        return losses

@@ -82,3 +82,4 @@ Import:
   - configs/minigpt4/metafile.yml
   - configs/llava/metafile.yml
   - configs/otter/metafile.yml
+  - configs/mff/metafile.yml

@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform

import pytest
import torch

from mmpretrain.models import MFF, MFFViT
from mmpretrain.structures import DataSample


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mff_vit():
    backbone = dict(
        arch='b', patch_size=16, mask_ratio=0.75, out_indices=[1, 11])
    mff_backbone = MFFViT(**backbone)
    mff_backbone.init_weights()
    fake_inputs = torch.randn((2, 3, 224, 224))

    # test with mask
    fake_outputs = mff_backbone(fake_inputs)[0]
    assert list(fake_outputs.shape) == [2, 50, 768]


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mff():
    data_preprocessor = {
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5],
        'to_rgb': True
    }
    backbone = dict(
        type='MFFViT',
        arch='b',
        patch_size=16,
        mask_ratio=0.75,
        out_indices=[1, 11])
    neck = dict(
        type='MAEPretrainDecoder',
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.,
    )
    loss = dict(type='PixelReconstructionLoss', criterion='L2')
    head = dict(
        type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)

    alg = MFF(
        backbone=backbone,
        neck=neck,
        head=head,
        data_preprocessor=data_preprocessor)

    fake_data = {
        'inputs': torch.randn((2, 3, 224, 224)),
        'data_samples': [DataSample() for _ in range(2)]
    }
    fake_inputs = alg.data_preprocessor(fake_data)
    fake_outputs = alg(**fake_inputs, mode='loss')
    assert isinstance(fake_outputs['loss'].item(), float)
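
# Run with pytest, e.g. (assumed file location):
#   pytest tests/test_models/test_selfsup/test_mff.py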