[Feature]: Add model zoo and fix lint

pull/728/head
liuyuan 2023-03-22 13:00:20 +08:00 committed by Yuan Liu
parent d5e737ff3e
commit aefea11df9
16 changed files with 72 additions and 28 deletions

View File

@ -157,6 +157,7 @@ Supported algorithms:
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)
- [x] [PixMIM (arXiv'2023)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/pixmim)
More algorithms are in our plan.

View File

@ -143,6 +143,7 @@ Useful Tools
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)
- [x] [PixMIM (arXiv'2023)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/pixmim)
更多的算法实现已经在我们的计划中。

View File

@ -1,9 +1,10 @@
# PixMIM
> [PixMIM: Rethinking Pixel Reconstruction in Masked Image Modeling
](https://arxiv.org/abs/2303.02416)
> ](https://arxiv.org/abs/2303.02416)
## TL;DR
## TL;DR
PixMIM can seamlessly replace MAE as a stronger baseline, with
negligible computational overhead.
@ -12,7 +13,7 @@ negligible computational overhead.
## Abstract
Masked Image Modeling (MIM) has achieved promising progress with the advent of Masked Autoencoders
(MAE) and BEiT. However, subsequent works have complicated the framework with new auxiliary tasks or extra pretrained models,
(MAE) and BEiT. However, subsequent works have complicated the framework with new auxiliary tasks or extra pretrained models,
inevitably increasing computational overhead. This paper undertakes a fundamental analysis of
MIM from the perspective of pixel reconstruction, which
examines the input image patches and reconstruction target, and highlights two critical but previously overlooked
@ -126,7 +127,6 @@ If you use a single machine without any cluster management software
GPUS=8 bash tools/benchmarks/classification/mim_dist_train.sh configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py $pretrained_model --amp
```
## Detection and Segmentation
If you want to evaluate your model on detection or segmentation tasks, we provide a [script](https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/tools/model_converters/mmcls2timm.py) to convert the model keys from MMClassification style to timm style.
@ -140,7 +140,6 @@ Then, using this converted ckpt, you can evaluate your model on detection task,
and on semantic segmentation task, following this [project](https://github.com/implus/mae_segmentation). Besides, using the unconverted ckpt, you can use
[MMSegmentation](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/mae) to evaluate your model.
## Citation
```bibtex

View File

@ -9,4 +9,4 @@ train_pipeline = [
dict(type='PackClsInputs'),
]
train_dataloader = dict(
batch_size=2048, dataset=dict(pipeline=train_pipeline), drop_last=True)
batch_size=2048, dataset=dict(pipeline=train_pipeline), drop_last=True)

View File

@ -44,7 +44,7 @@ Models:
Top 1 Accuracy: 63.3
Config: configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.pth
- Name: pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k
In Collection: PixMIM
Metadata:
@ -75,4 +75,4 @@ Models:
Metrics:
Top 1 Accuracy: 67.5
Config: configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.pth
Weights: https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.pth

View File

@ -24,4 +24,4 @@ model = dict(
)
# randomness
randomness = dict(seed=2, diff_rank_seed=True)
randomness = dict(seed=2, diff_rank_seed=True)

View File

@ -19,4 +19,4 @@ param_scheduler = [
begin=40,
end=800,
convert_to_iter_based=True)
]
]

View File

@ -441,5 +441,26 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. T
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
</tr>
<tr>
<td rowspan="2">PixMIM</td>
<td>ViT-base</td>
<td>300</td>
<td>4096</td>
<td>63.3</td>
<td>83.1</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230322-3304a88c.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230322-3304a88c.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-7eba2bc2.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-7eba2bc2.json'> log </a></td>
</tr>
<tr>
<td>ViT-base</td>
<td>800</td>
<td>4096</td>
<td>67.5</td>
<td>83.5</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230322-e8137924.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230322-e8137924.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-616b1a7f.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-616b1a7f.json'> log </a></td>
</tr>
</tbody>
</table>

View File

@ -441,5 +441,26 @@ ImageNet 有多个版本,不过最常用的是 ILSVRC 2012。我们提供了
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
</tr>
<tr>
<td rowspan="2">PixMIM</td>
<td>ViT-base</td>
<td>300</td>
<td>4096</td>
<td>63.3</td>
<td>83.1</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230322-3304a88c.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k_20230322-3304a88c.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-72322af8.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-7eba2bc2.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-300e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-7eba2bc2.json'> log </a></td>
</tr>
<tr>
<td>ViT-base</td>
<td>800</td>
<td>4096</td>
<td>67.5</td>
<td>83.5</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230322-e8137924.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k_20230322-e8137924.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k/vit-base-p16_linear-8xb2048-torchvision-transform-coslr-90e_in1k_20230322-12c15568.json'> log </a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/1.x/configs/selfsup/pixmim/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'> config </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-616b1a7f.pth'> model </a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/pixmim/pixmim_vit-base-p16_8xb512-amp-coslr-800e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20230322-616b1a7f.json'> log </a></td>
</tr>
</tbody>
</table>

View File

@ -6,8 +6,8 @@ from .processing import (BEiTMaskGenerator, ColorJitter, RandomCrop,
RandomResizedCropAndInterpolationWithTwoPic,
RandomRotation, RandomSolarize, RotationWithLabels,
SimMIMMaskGenerator)
from .wrappers import MultiView
from .pytorch_transform import MAERandomResizedCrop
from .wrappers import MultiView
__all__ = [
'PackSelfSupInputs', 'RandomGaussianBlur', 'RandomSolarize',

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Tuple
import torch
import torchvision.transforms.functional as F
from torchvision import transforms
from typing import Tuple
import math
from mmselfsup.registry import TRANSFORMS
@ -40,8 +41,7 @@ class MAERandomResizedCrop(transforms.RandomResizedCrop):
return i, j, h, w
def forward(self, results: dict) -> dict:
"""
The forward function of MAERandomResizedCrop.
"""The forward function of MAERandomResizedCrop.
Args:
results (dict): The results dict contains the image and all these
@ -55,4 +55,4 @@ class MAERandomResizedCrop(transforms.RandomResizedCrop):
i, j, h, w = self.get_params(img, self.scale, self.ratio)
img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
results['img'] = img
return results
return results

View File

@ -15,13 +15,13 @@ from .moco import MoCo
from .mocov3 import MoCoV3
from .npid import NPID
from .odc import ODC
from .pixmim import PixMIM
from .relative_loc import RelativeLoc
from .rotation_pred import RotationPred
from .simclr import SimCLR
from .simmim import SimMIM
from .simsiam import SimSiam
from .swav import SwAV
from .pixmim import PixMIM
__all__ = [
'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',

View File

@ -12,7 +12,7 @@ from .mae import MAE
class PixMIM(MAE):
"""The official implementation of PixMIM.
Implementation of `PixMIM: Rethinking Pixel Reconstruction in
Implementation of `PixMIM: Rethinking Pixel Reconstruction in
Masked Image Modeling <https://arxiv.org/pdf/2303.02416.pdf>`_.
Please refer to MAE for these initialization arguments.
@ -37,4 +37,4 @@ class PixMIM(MAE):
pred = self.neck(latent, ids_restore)
loss = self.head(pred, low_freq_targets, mask)
losses = dict(loss=loss)
return losses
return losses

View File

@ -2,8 +2,8 @@
from .clip_generator import CLIPGenerator
from .dall_e import Encoder
from .hog_generator import HOGGenerator
from .vqkd import VQKD
from .low_freq_generator import LowFreqTargetGenerator
from .vqkd import VQKD
__all__ = [
'HOGGenerator', 'VQKD', 'Encoder', 'CLIPGenerator',

View File

@ -11,13 +11,12 @@ from mmselfsup.registry import MODELS
class LowFreqTargetGenerator(BaseModule):
"""Generate low-frequency target for images.
This module is used in PixMIM: Rethinking Pixel Reconstruction in Masked
This module is used in PixMIM: Rethinking Pixel Reconstruction in Masked
Image Modeling to remove high-frequency information from images.
Args:
radius (int): radius of low pass filter.
img_size (Union[int, Tuple[int, int]]): size of input images.
"""
def __init__(self, radius: int, img_size: Union[int, Tuple[int,
@ -55,7 +54,7 @@ class LowFreqTargetGenerator(BaseModule):
Args:
imgs (torch.Tensor): input images, which have shape (N, C, H, W).
Returns:
torch.Tensor: low frequency target, which has the same shape as
input images.
@ -77,4 +76,4 @@ class LowFreqTargetGenerator(BaseModule):
low_pass_imgs = (low_pass_imgs - mean) / std
return low_pass_imgs
return low_pass_imgs

View File

@ -1,7 +1,9 @@
import torch
from mmselfsup.datasets.transforms import MAERandomResizedCrop
from PIL import Image
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from PIL import Image
from mmselfsup.datasets.transforms import MAERandomResizedCrop
def test_mae_random_resized_crop():
@ -20,4 +22,4 @@ def test_mae_random_resized_crop():
assert list(results['img'].shape) == [224, 224, 3]
# test repr
assert isinstance(str(module), str)
assert isinstance(str(module), str)