[Feature]: Add MFF (#1725)

* [Feature]: Add MFF

* [Feature]: Add MFF linear probe

* [Feature]: Add ft

* [Fix]: Update docstring

* [Feature]: Update out_indices

* [Feature]: Add prefix to ft

* [Feature]: Add README

* [Feature]: Update readme

* [Feature]: Update README

* [Feature]: Add metafile

* [Feature]: Update README

* [Fix]: Fix lint

* [Feature]: Add UT

* [Feature]: Update paper link
Yuan Liu 2023-08-08 16:01:07 +08:00 committed by GitHub
parent 2fb52eefdc
commit fa53174fd9
16 changed files with 721 additions and 5 deletions

View File

@ -253,6 +253,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
<li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
<li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
<li><a href="configs/spark">SparK (ICLR'2023)</a></li>
<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
</ul>
</td>
<td>

View File

@ -249,6 +249,7 @@ mim install -e ".[multimodal]"
<li><a href="configs/mixmim">MixMIM (arXiv'2022)</a></li>
<li><a href="configs/itpn">iTPN (CVPR'2023)</a></li>
<li><a href="configs/spark">SparK (ICLR'2023)</a></li>
<li><a href="configs/mff">MFF (ICCV'2023)</a></li>
</ul>
</td>
<td>

View File

@ -0,0 +1,60 @@
# MFF
> [Improving Pixel-based MIM by Reducing Wasted Modeling Capability](https://arxiv.org/abs/2308.00261)
<!-- [ALGORITHM] -->
## Abstract
There has been significant progress in Masked Image Modeling (MIM). Existing MIM methods can be broadly categorized into two groups based on the reconstruction target: pixel-based and tokenizer-based approaches. The former offers a simpler pipeline and lower computational cost, but it is known to be biased toward high-frequency details. In this paper, we provide a set of empirical studies to confirm this limitation of pixel-based MIM and propose a new method that explicitly utilizes low-level features from shallow layers to aid pixel reconstruction. By incorporating this design into our base method, MAE, we reduce the wasted modeling capability of pixel-based MIM, improving its convergence and achieving non-trivial improvements across various downstream tasks. To the best of our knowledge, we are the first to systematically investigate multi-level feature fusion for isotropic architectures like the standard Vision Transformer (ViT). Notably, when applied to a smaller model (e.g., ViT-S), our method yields significant performance gains, such as 1.2% on fine-tuning, 2.8% on linear probing, and 2.6% on semantic segmentation.
<div align=center>
<img src="https://user-images.githubusercontent.com/30762564/257412932-5f36b11b-ee64-4ce7-b7d1-a31000302bd8.png" width="80%"/>
</div>
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
```
Test:
```shell
python tools/test.py configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py None
```
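The fine-tuned classifiers registered by this PR's metafile can also be tried from the Python API. A minimal sketch, assuming the checkpoint is downloadable and using `demo/bird.JPEG` from the repository as the test image:

```python
from mmpretrain import inference_model

# The model name comes from this PR's metafile; the checkpoint is fetched
# automatically from download.openmmlab.com on first use.
predict = inference_model(
    'vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k', 'demo/bird.JPEG')
print(predict['pred_class'], predict['pred_score'])
```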
<!-- [TABS-END] -->
## Models and results
### Pretrained models
| Model | Params (M) | Flops (G) | Config | Download |
| :-------------------------------------------- | :--------: | :-------: | :------------------------------------------------------: | :------------------------------------------------------------------------------: |
| `mff_vit-base-p16_8xb512-amp-coslr-300e_in1k` | - | - | [config](mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.json) |
| `mff_vit-base-p16_8xb512-amp-coslr-800e_in1k` | - | - | [config](mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.json) |
### Image Classification on ImageNet-1k
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Config | Download |
| :---------------------------------------- | :------------------------------------------: | :--------: | :-------: | :-------: | :----------------------------------------: | :-------------------------------------------: |
| `vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k` | [MFF 300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) | 86.57 | 17.58 | 83.00 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.json) |
| `vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) | 86.57 | 17.58 | 83.70 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.json) |
| `vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 300-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth) | 86.57 | 17.58 | 64.20 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k/vit-base-p16_8xb2048-linear-coslr-90e_in1k.json) |
| `vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k` | [MFF 800-Epochs](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth) | 86.57 | 17.58 | 68.30 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.json) |
## Citation
```bibtex
@article{MFF,
title={Improving Pixel-based MIM by Reducing Wasted Modeling Capability},
  author={Yuan Liu and Songyang Zhang and Jiacheng Chen and Zhaohui Yu and Kai Chen and Dahua Lin},
  journal={arXiv preprint arXiv:2308.00261},
year={2023}
}
```

View File

@ -0,0 +1,114 @@
_base_ = [
'../../_base_/datasets/imagenet_bs64_swin_224.py',
'../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../../_base_/default_runtime.py'
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs')
]
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
drop_path_rate=0.1,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=2e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.65,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=95,
by_epoch=True,
begin=5,
end=100,
eta_min=1e-6,
convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
train_cfg = dict(by_epoch=True, max_epochs=100)
randomness = dict(seed=0, diff_rank_seed=True)
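The `LearningRateDecayOptimWrapperConstructor` above scales learning rates per layer depth via `layer_decay_rate=0.65`. A rough illustrative sketch of the resulting per-depth scales for a 12-layer ViT-B; the helper below is hypothetical and not part of mmpretrain:

```python
def layer_lr_scales(base_lr: float = 2e-3,
                    decay: float = 0.65,
                    num_layers: int = 12) -> list:
    """Hypothetical per-depth learning rates under layer-wise decay.

    Depth 0 is the patch embedding, depths 1..num_layers are the
    transformer blocks, and depth num_layers + 1 is the head (scale 1.0).
    """
    return [base_lr * decay**(num_layers + 1 - d) for d in range(num_layers + 2)]


scales = layer_lr_scales()
# The head trains at the full 2e-3, while the patch embedding gets roughly
# 2e-3 * 0.65 ** 13 ≈ 7.4e-06, so the earliest layers barely move.
print(f'{scales[0]:.2e}', f'{scales[-1]:.2e}')
```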

View File

@ -0,0 +1,74 @@
_base_ = [
'../../_base_/datasets/imagenet_bs32_pil_resize.py',
'../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../../_base_/default_runtime.py'
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ToPIL', to_rgb=True),
dict(type='MAERandomResizedCrop', size=224, interpolation=3),
dict(type='torchvision/RandomHorizontalFlip', p=0.5),
dict(type='ToNumpy', to_bgr=True),
dict(type='PackInputs'),
]
# dataset settings
train_dataloader = dict(
batch_size=2048, drop_last=True, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(drop_last=False)
test_dataloader = dict(drop_last=False)
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
frozen_stages=12,
out_type='cls_token',
final_norm=True,
init_cfg=dict(type='Pretrained', prefix='backbone.')),
neck=dict(type='ClsBatchNormNeck', input_features=768),
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.01)]))
# optimizer
optim_wrapper = dict(
_delete_=True,
type='AmpOptimWrapper',
optimizer=dict(type='LARS', lr=6.4, weight_decay=0.0, momentum=0.9))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=80,
by_epoch=True,
begin=10,
end=90,
eta_min=0.0,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=90)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
logger=dict(type='LoggerHook', interval=10))
randomness = dict(seed=0, diff_rank_seed=True)
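With `frozen_stages=12` the whole backbone is frozen, so only the `ClsBatchNormNeck` and the classification head are optimized. A generic PyTorch sketch of this linear-probing setup, using stand-in modules rather than the mmpretrain classes:

```python
import torch
import torch.nn as nn

embed_dim, num_classes = 768, 1000  # ViT-B feature size, ImageNet-1k classes

# Stand-in for the pretrained ViT-B backbone; in the config it is frozen.
backbone = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(embed_dim, nhead=12, batch_first=True),
    num_layers=12)
for param in backbone.parameters():
    param.requires_grad = False  # linear probing: the backbone is never updated

# Only the probe is optimized: a BatchNorm without affine parameters
# (MAE-style linear probing) followed by a linear classifier.
probe = nn.Sequential(
    nn.BatchNorm1d(embed_dim, affine=False),
    nn.Linear(embed_dim, num_classes))

tokens = torch.randn(4, 197, embed_dim)  # 196 patch tokens + 1 cls token
cls_feature = backbone(tokens)[:, 0]     # cls-token feature, shape (4, 768)
logits = probe(cls_feature)
print(logits.shape)                      # torch.Size([4, 1000])
```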

View File

@ -0,0 +1,103 @@
Collections:
- Name: MFF
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- AdamW
Training Resources: 8x A100-80G GPUs
Architecture:
- ViT
Paper:
Title: Improving Pixel-based MIM by Reducing Wasted Modeling Capability
URL: https://arxiv.org/pdf/2308.00261.pdf
README: configs/mff/README.md
Models:
- Name: mff_vit-base-p16_8xb512-amp-coslr-300e_in1k
Metadata:
Epochs: 300
      Batch Size: 4096
FLOPs: 17581972224
Parameters: 85882692
Training Data: ImageNet-1k
    In Collection: MFF
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230801-3c1bcce4.pth
Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
Downstream:
- vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
- vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
- Name: mff_vit-base-p16_8xb512-amp-coslr-800e_in1k
Metadata:
Epochs: 800
      Batch Size: 4096
FLOPs: 17581972224
Parameters: 85882692
Training Data: ImageNet-1k
    In Collection: MFF
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230801-3af7cd9d.pth
Config: configs/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k.py
Downstream:
- vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
- vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
- Name: vit-base-p16_mff-300e-pre_8xb128-coslr-100e_in1k
Metadata:
Epochs: 100
Batch Size: 1024
FLOPs: 17581215744
Parameters: 86566120
Training Data: ImageNet-1k
    In Collection: MFF
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.0
Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_8xb128-coslr-100e_in1k/vit-base-p16_8xb128-coslr-100e_in1k_20230802-d746fdb7.pth
Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
- Name: vit-base-p16_mff-800e-pre_8xb128-coslr-100e_in1k
Metadata:
Epochs: 100
Batch Size: 1024
FLOPs: 17581215744
Parameters: 86566120
Training Data: ImageNet-1k
In Collection: MFF
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.7
Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb128-coslr-100e/vit-base-p16_8xb128-coslr-100e_20230802-6780e47d.pth
Config: configs/mff/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py
- Name: vit-base-p16_mff-300e-pre_8xb2048-linear-coslr-90e_in1k
Metadata:
Epochs: 90
Batch Size: 16384
FLOPs: 17581215744
Parameters: 86566120
Training Data: ImageNet-1k
In Collection: MFF
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 64.2
Weights:
Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py
- Name: vit-base-p16_mff-800e-pre_8xb2048-linear-coslr-90e_in1k
Metadata:
Epochs: 90
Batch Size: 16384
FLOPs: 17581215744
Parameters: 86566120
Training Data: ImageNet-1k
In Collection: MFF
Results:
- Task: Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 68.3
    Weights: https://download.openmmlab.com/mmpretrain/v1.0/mff/mff_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_8xb2048-linear-coslr-90e/vit-base-p16_8xb2048-linear-coslr-90e_20230802-6b1f7bc8.pth
Config: configs/mff/benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py

View File

@ -0,0 +1,24 @@
_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'
randomness = dict(seed=2, diff_rank_seed=True)
# dataset config
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ToPIL', to_rgb=True),
dict(type='torchvision/Resize', size=224),
dict(
type='torchvision/RandomCrop',
size=224,
padding=4,
padding_mode='reflect'),
dict(type='torchvision/RandomHorizontalFlip', p=0.5),
dict(type='ToNumpy', to_bgr=True),
dict(type='PackInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
# model config
model = dict(
type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11]))
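This config only swaps the MAE algorithm and backbone for `MFF`/`MFFViT` and selects which layers to fuse via `out_indices`. A quick shape-check sketch, assuming a working mmpretrain 1.x install and that it is run from the repository root so the `_base_` configs resolve:

```python
import torch
from mmengine.config import Config
from mmengine.registry import init_default_scope

import mmpretrain.models  # noqa: F401, ensures MFF/MFFViT are registered
from mmpretrain.registry import MODELS

init_default_scope('mmpretrain')
cfg = Config.fromfile('configs/mff/mff_vit-base-p16_8xb512-amp-coslr-300e_in1k.py')
model = MODELS.build(cfg.model)

# With masking on, the backbone returns the fused latent, the mask, the
# restore ids, and the six softmax fusion weights (one per out_indices entry).
latent, mask, ids_restore, weights = model.backbone(torch.randn(2, 3, 224, 224))
print(latent.shape, weights.shape)  # torch.Size([2, 50, 768]) torch.Size([6])
```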

View File

@ -0,0 +1,24 @@
_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-800e_in1k.py'
randomness = dict(seed=2, diff_rank_seed=True)
# dataset config
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='ToPIL', to_rgb=True),
dict(type='torchvision/Resize', size=224),
dict(
type='torchvision/RandomCrop',
size=224,
padding=4,
padding_mode='reflect'),
dict(type='torchvision/RandomHorizontalFlip', p=0.5),
dict(type='ToNumpy', to_bgr=True),
dict(type='PackInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
# model config
model = dict(
type='MFF', backbone=dict(type='MFFViT', out_indices=[0, 2, 4, 6, 8, 11]))

View File

@ -41,8 +41,8 @@ if WITH_MULTIMODAL:
from .flickr30k_caption import Flickr30kCaption
from .flickr30k_retrieval import Flickr30kRetrieval
from .gqa_dataset import GQA
from .infographic_vqa import InfographicVQA
from .iconqa import IconQA
from .infographic_vqa import InfographicVQA
from .nocaps import NoCaps
from .ocr_vqa import OCRVQA
from .refcoco import RefCOCO

View File

@ -12,8 +12,9 @@ from .formatting import (Collect, NumpyToPIL, PackInputs, PackMultiTaskInputs,
PILToNumpy, Transpose)
from .processing import (Albumentations, BEiTMaskGenerator, CleanCaption,
ColorJitter, EfficientNetCenterCrop,
EfficientNetRandomCrop, Lighting, RandomCrop,
RandomErasing, RandomResizedCrop,
EfficientNetRandomCrop, Lighting,
MAERandomResizedCrop, RandomCrop, RandomErasing,
RandomResizedCrop,
RandomResizedCropAndInterpolationWithTwoPic,
RandomTranslatePad, ResizeEdge, SimMIMMaskGenerator)
from .utils import get_transform_idx, remove_transform
@ -36,5 +37,5 @@ __all__ = [
'RandomFlip', 'RandomGrayscale', 'RandomResize', 'Resize', 'MultiView',
'ApplyToList', 'CleanCaption', 'RandomTranslatePad',
'RandomResizedCropAndInterpolationWithTwoPic', 'get_transform_idx',
'remove_transform'
'remove_transform', 'MAERandomResizedCrop'
]

View File

@ -11,9 +11,13 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
import mmcv
import mmengine
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as F
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from PIL import Image
from torchvision import transforms
from torchvision.transforms.transforms import InterpolationMode
from mmpretrain.registry import TRANSFORMS
@ -1740,3 +1744,52 @@ class RandomTranslatePad(BaseTransform):
results['gt_bboxes'][i] = box
return results
@TRANSFORMS.register_module()
class MAERandomResizedCrop(transforms.RandomResizedCrop):
"""RandomResizedCrop for matching TF/TPU implementation: no for-loop is
used.
    This may lead to results different from torchvision's version.
Following BYOL's TF code:
https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 # noqa: E501
"""
@staticmethod
def get_params(img: Image.Image, scale: tuple, ratio: tuple) -> Tuple:
width, height = img.size
area = height * width
target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
log_ratio = torch.log(torch.tensor(ratio))
aspect_ratio = torch.exp(
torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
w = min(w, width)
h = min(h, height)
i = torch.randint(0, height - h + 1, size=(1, )).item()
j = torch.randint(0, width - w + 1, size=(1, )).item()
return i, j, h, w
def forward(self, results: dict) -> dict:
"""The forward function of MAERandomResizedCrop.
Args:
            results (dict): The results dict containing the image and its
                related information.
        Returns:
            dict: The results dict containing the cropped image and its
                related information.
"""
img = results['img']
i, j, h, w = self.get_params(img, self.scale, self.ratio)
img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
results['img'] = img
return results
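Unlike torchvision's `RandomResizedCrop`, this variant samples the crop parameters once and clamps them to the image instead of retrying in a loop. A small usage sketch; the transform expects a PIL image under the `img` key, as produced by `ToPIL` in the pretraining pipelines of this PR:

```python
from PIL import Image

from mmpretrain.datasets.transforms import MAERandomResizedCrop

# interpolation=3 is bicubic, matching the MFF pretraining configs.
transform = MAERandomResizedCrop(size=224, interpolation=3)

results = {'img': Image.new('RGB', (500, 375))}  # dummy PIL image
results = transform(results)
print(results['img'].size)  # (224, 224)
```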

View File

@ -436,7 +436,7 @@ class VisionTransformer(BaseBackbone):
for param in self.pre_norm.parameters():
param.requires_grad = False
# freeze cls_token
if self.cls_token:
if self.cls_token is not None:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):

View File

@ -9,6 +9,7 @@ from .eva import EVA
from .itpn import iTPN, iTPNHiViT
from .mae import MAE, MAEHiViT, MAEViT
from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT
from .mff import MFF, MFFViT
from .milan import MILAN, CLIPGenerator, MILANViT
from .mixmim import MixMIM, MixMIMPretrainTransformer
from .moco import MoCo
@ -53,4 +54,6 @@ __all__ = [
'BarlowTwins',
'SwAV',
'SparK',
'MFF',
'MFFViT',
]

View File

@ -0,0 +1,194 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from mmpretrain.models.selfsup.mae import MAE, MAEViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
@MODELS.register_module()
class MFFViT(MAEViT):
"""Vision Transformer for MFF Pretraining.
    This class inherits all its functionalities from ``MAEViT`` and adds
    multi-level feature fusion on top of it. For more details, please refer
    to `Improving Pixel-based MIM by Reducing Wasted Modeling Capability`.
Args:
        arch (str | dict): Vision Transformer architecture.
            Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
It only works without input mask. Defaults to ``"avg_featmap"``.
        interpolate_mode (str): Select the interpolation mode for the
            position embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
        mask_ratio (float): The ratio of total number of patches to be masked.
Defaults to 0.75.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'b',
img_size: int = 224,
patch_size: int = 16,
out_indices: Union[Sequence, int] = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
out_type: str = 'raw',
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
mask_ratio: float = 0.75,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
out_type=out_type,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
mask_ratio=mask_ratio,
init_cfg=init_cfg)
proj_layers = [
torch.nn.Linear(self.embed_dims, self.embed_dims)
for _ in range(len(self.out_indices) - 1)
]
self.proj_layers = torch.nn.ModuleList(proj_layers)
self.proj_weights = torch.nn.Parameter(
torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
if len(self.out_indices) == 1:
self.proj_weights.requires_grad = False
def forward(
self,
x: torch.Tensor,
mask: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate features for masked images.
        The function supports two kinds of forward behaviors. If the ``mask``
        is ``True``, the function will randomly generate a mask, mask some
        patches, and get the hidden features of the visible patches, i.e. it
        runs as masked image modeling pre-training; if the ``mask`` is
        ``None`` or ``False``, the forward function will call
        ``super().forward()``, which extracts features from images without
        masking.
Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): Whether the forward function should
                generate a ``mask`` and mask the input. Defaults to True.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
mask and the ids to restore original image.
- ``x`` (torch.Tensor): hidden features, which is of shape
B x (L * mask_ratio) x C.
- ``mask`` (torch.Tensor): mask used to mask image.
- ``ids_restore`` (torch.Tensor): ids to restore original image.
"""
        if mask is None or mask is False:
return super().forward(x)
else:
B = x.shape[0]
x = self.patch_embed(x)[0]
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
res = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
if i != self.out_indices[-1]:
proj_x = self.proj_layers[self.out_indices.index(i)](x)
else:
proj_x = x
res.append(proj_x)
res = torch.stack(res)
proj_weights = F.softmax(self.proj_weights, dim=0)
res = res * proj_weights
res = res.sum(dim=0)
# Use final norm
x = self.norm1(res)
return (x, mask, ids_restore, proj_weights.view(-1))
@MODELS.register_module()
class MFF(MAE):
"""MFF.
Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
Capability`.
"""
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (torch.Tensor): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
# ids_restore: the same as that in original repo, which is used
# to recover the original order of tokens in decoder.
latent, mask, ids_restore, weights = self.backbone(inputs)
pred = self.neck(latent, ids_restore)
loss = self.head.loss(pred, inputs, mask)
weight_params = {
f'weight_{i}': weights[i]
for i in range(weights.size(0))
}
losses = dict(loss=loss)
losses.update(weight_params)
return losses
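The essential new piece in `MFFViT.forward` is the softmax-weighted fusion of projected intermediate features. A self-contained illustrative re-statement of just that step; the `FeatureFusion` module below is a sketch, not mmpretrain code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureFusion(nn.Module):
    """Illustrative multi-level feature fusion in the spirit of MFFViT."""

    def __init__(self, embed_dims: int = 768, num_levels: int = 6) -> None:
        super().__init__()
        # One linear projection per intermediate level; the final level is used as-is.
        self.proj_layers = nn.ModuleList(
            [nn.Linear(embed_dims, embed_dims) for _ in range(num_levels - 1)])
        self.proj_weights = nn.Parameter(torch.ones(num_levels).view(-1, 1, 1, 1))

    def forward(self, feats: list) -> torch.Tensor:
        # feats: one (B, L, C) tensor per selected transformer layer.
        projected = [proj(x) for proj, x in zip(self.proj_layers, feats[:-1])]
        projected.append(feats[-1])
        stacked = torch.stack(projected)               # (num_levels, B, L, C)
        weights = F.softmax(self.proj_weights, dim=0)  # learned fusion weights
        return (stacked * weights).sum(dim=0)          # (B, L, C)


fusion = FeatureFusion()
feats = [torch.randn(2, 50, 768) for _ in range(6)]  # e.g. layers 0, 2, 4, 6, 8, 11
print(fusion(feats).shape)  # torch.Size([2, 50, 768])
```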

View File

@ -82,3 +82,4 @@ Import:
- configs/minigpt4/metafile.yml
- configs/llava/metafile.yml
- configs/otter/metafile.yml
- configs/mff/metafile.yml

View File

@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import MFF, MFFViT
from mmpretrain.structures import DataSample
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mff_vit():
backbone = dict(
arch='b', patch_size=16, mask_ratio=0.75, out_indices=[1, 11])
    mff_backbone = MFFViT(**backbone)
    mff_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 224, 224))
# test with mask
    fake_outputs = mff_backbone(fake_inputs)[0]
assert list(fake_outputs.shape) == [2, 50, 768]
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mff():
data_preprocessor = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
}
backbone = dict(
type='MFFViT',
arch='b',
patch_size=16,
mask_ratio=0.75,
out_indices=[1, 11])
neck = dict(
type='MAEPretrainDecoder',
patch_size=16,
in_chans=3,
embed_dim=768,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.,
)
loss = dict(type='PixelReconstructionLoss', criterion='L2')
head = dict(
type='MAEPretrainHead', norm_pix=False, patch_size=16, loss=loss)
alg = MFF(
backbone=backbone,
neck=neck,
head=head,
data_preprocessor=data_preprocessor)
fake_data = {
'inputs': torch.randn((2, 3, 224, 224)),
'data_samples': [DataSample() for _ in range(2)]
}
fake_inputs = alg.data_preprocessor(fake_data)
fake_outputs = alg(**fake_inputs, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)