[Enhance] Add metafile, readme and converted models for MLP-Mixer (#539)

* add pth converter

* minor update on config files, add metafile and readme

* add missing readme and minor fixes

* minor fixes

* Update config names and checkpoint download link

* Update model_zoo.md

Co-authored-by: mzr1996 <mzr1996@163.com>
Zhicheng Chen 2021-11-24 19:04:19 +08:00 committed by GitHub
parent f3fbc8b90b
commit fc8adbc149
11 changed files with 215 additions and 15 deletions


@@ -0,0 +1,48 @@
# dataset settings
dataset_type = 'ImageNet'
# change according to https://github.com/rwightman/pytorch-image-models/blob
# /master/timm/models/mlp_mixer.py
img_norm_cfg = dict(
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
# training is not supported for now
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224, backend='cv2'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize', size=(256, -1), backend='cv2', interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_prefix='data/imagenet/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline),
test=dict(
# replace `data/val` with `data/test` for standard test
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')
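The new base dataset config above is consumed through the usual mmcls config machinery. Below is a minimal sketch, assuming an ImageNet layout under `data/imagenet` and the full MLP-Mixer config added later in this PR, of building the test dataloader from it:

```python
from mmcv import Config
from mmcls.datasets import build_dataloader, build_dataset

# Load the full config, which pulls in this dataset base file via _base_.
cfg = Config.fromfile('configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py')

# Build the test dataset and a single-GPU dataloader from the `data` dict above.
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=cfg.data.samples_per_gpu,
    workers_per_gpu=cfg.data.workers_per_gpu,
    dist=False,
    shuffle=False)
```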


@@ -3,9 +3,9 @@ model = dict(
     type='ImageClassifier',
     backbone=dict(
         type='MlpMixer',
-        arch='b',
+        arch='l',
         img_size=224,
-        patch_size=32,
+        patch_size=16,
         drop_rate=0.1,
         init_cfg=[
             dict(
@@ -18,7 +18,7 @@ model = dict(
     head=dict(
         type='LinearClsHead',
         num_classes=1000,
-        in_channels=768,
+        in_channels=1024,
         loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
         topk=(1, 5),
     ),
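The switch to `arch='l'` with `patch_size=16` is what moves the head's `in_channels` from 768 to 1024, since the large Mixer uses 1024-dimensional embeddings. A small sanity-check sketch, assuming the `MlpMixer` backbone from this repo is importable:

```python
import torch
from mmcls.models.backbones import MlpMixer

# Build the large Mixer backbone and run a dummy image through it; the
# embedding dimension of the output features is expected to be 1024.
backbone = MlpMixer(arch='l', img_size=224, patch_size=16)
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))
out = feats[-1] if isinstance(feats, tuple) else feats
print(out.shape)  # one axis should be the 1024-dim embedding
```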


@@ -0,0 +1,37 @@
# MLP-Mixer: An all-MLP Architecture for Vision
<!-- {Mlp-Mixer} -->
<!-- [ALGORITHM] -->
## Abstract
<!-- [ABSTRACT] -->
Convolutional Neural Networks (CNNs) are the go-to model for computer vision. Recently, attention-based networks, such as the Vision Transformer, have also become popular. In this paper we show that while convolutions and attention are both sufficient for good performance, neither of them are necessary. We present MLP-Mixer, an architecture based exclusively on multi-layer perceptrons (MLPs). MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. "mixing" the per-location features), and one with MLPs applied across patches (i.e. "mixing" spatial information). When trained on large datasets, or with modern regularization schemes, MLP-Mixer attains competitive scores on image classification benchmarks, with pre-training and inference cost comparable to state-of-the-art models. We hope that these results spark further research beyond the realms of well established CNNs and Transformers.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/143178327-7118b48a-5f5f-4844-a614-a571917384ca.png" width="90%"/>
</div>
## Citation
```latex
@misc{tolstikhin2021mlpmixer,
title={MLP-Mixer: An all-MLP Architecture for Vision},
author={Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Andreas Steiner and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
year={2021},
eprint={2105.01601},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Pretrained models
The pre-trained models are converted from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py).
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:--------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth)|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth)|
*Models with \* are converted from other repos.*
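For reference, a minimal inference sketch with one of the converted checkpoints (file paths are placeholders; any test image works):

```python
from mmcls.apis import inference_model, init_model

config_file = 'configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py'
checkpoint_file = 'mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth'

# Build the classifier from the config and load the converted weights.
model = init_model(config_file, checkpoint_file, device='cpu')

# Classify a single image; the result contains pred_label/pred_score/pred_class.
result = inference_model(model, 'demo/demo.JPEG')
print(result)
```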


@@ -0,0 +1,50 @@
Collections:
- Name: MLP-Mixer
Metadata:
Training Data: ImageNet-1k
Architecture:
- MLP
- Layer Normalization
- Dropout
Paper:
URL: https://arxiv.org/abs/2105.01601
Title: "MLP-Mixer: An all-MLP Architecture for Vision"
README: configs/mlp_mixer/README.md
# Code:
# URL: # todo
# Version: # todo
Models:
- Name: mlp-mixer-base-p16_3rdparty_64xb64_in1k
In Collection: MLP-Mixer
Config: configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py
Metadata:
FLOPs: 12610000000 # 12.61 G
Parameters: 59880000 # 59.88 M
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 76.68
Top 5 Accuracy: 92.25
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth
Converted From:
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth
Code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py#L70
- Name: mlp-mixer-large-p16_3rdparty_64xb64_in1k
In Collection: MLP-Mixer
Config: configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py
Metadata:
FLOPs: 44570000000 # 44.57 G
Parameters: 208200000 # 208.2 M
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 72.34
Top 5 Accuracy: 88.02
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth
Converted From:
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth
Code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py#L73
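The metafile is plain YAML, so downstream tooling can read it directly. A small sketch, assuming PyYAML is available, that lists the released checkpoints and their top-1 accuracy:

```python
import yaml

# Parse the metafile and print each model's name, top-1 accuracy and weights URL.
with open('configs/mlp_mixer/metafile.yml') as f:
    index = yaml.safe_load(f)

for model in index['Models']:
    metrics = model['Results'][0]['Metrics']
    print(model['Name'], metrics['Top 1 Accuracy'], model['Weights'])
```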


@@ -1,6 +1,6 @@
 _base_ = [
     '../_base_/models/mlp_mixer_base_patch16.py',
-    '../_base_/datasets/imagenet_bs64_pil_resize.py',
+    '../_base_/datasets/imagenet_bs64_mixer_224.py',
     '../_base_/schedules/imagenet_bs4096_AdamW.py',
     '../_base_/default_runtime.py',
 ]


@@ -0,0 +1,6 @@
_base_ = [
'../_base_/models/mlp_mixer_large_patch16.py',
'../_base_/datasets/imagenet_bs64_mixer_224.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py',
]
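Both entry configs are thin wrappers that only list `_base_` files; mmcv merges them at load time. A quick sketch to confirm the merged settings:

```python
from mmcv import Config

# Loading the entry config merges the model, dataset, schedule and runtime bases.
cfg = Config.fromfile('configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py')
print(cfg.model.backbone.arch)      # 'l', from the model base file
print(cfg.model.head.in_channels)   # 1024
print(cfg.data.samples_per_gpu)     # 64, from the dataset base file
```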


@@ -1,6 +0,0 @@
_base_ = [
'../_base_/models/mlp_mixer_base_patch32.py',
'../_base_/datasets/imagenet_bs64_pil_resize.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py',
]


@@ -61,6 +61,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) &#124; [log]()|
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) &#124; [log]()|
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) &#124; [log]()|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) &#124; [log]()|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) &#124; [log]()|
Models with * are converted from other repos, others are trained by ourselves.


@@ -12,7 +12,10 @@ from .base_backbone import BaseBackbone
class MixerBlock(BaseModule):
"""Implements mixer block in MLP Mixer.
"""Mlp-Mixer basic block.
Basic module of `MLP-Mixer: An all-MLP Architecture for Vision
<https://arxiv.org/pdf/2105.01601.pdf>`_
Args:
num_tokens (int): The number of patched tokens
@@ -96,13 +99,14 @@ class MixerBlock(BaseModule):
@BACKBONES.register_module()
class MlpMixer(BaseBackbone):
"""Mlp Mixer.
"""Mlp-Mixer backbone.
Pytorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision
<https://arxiv.org/pdf/2105.01601.pdf>`_
A PyTorch implement of : `MLP-Mixer: An all-MLP Architecture for Vision` -
https://arxiv.org/abs/2105.01601
Args:
arch (str | dict): MLP Mixer architecture
Default: 'b'.
Defaults to 'b'.
img_size (int | tuple): Input image size.
patch_size (int | tuple): The patch size.
out_indices (Sequence | int): Output from which layer.


@@ -12,3 +12,4 @@ Import:
- configs/tnt/metafile.yml
- configs/vision_transformer/metafile.yml
- configs/t2t_vit/metafile.yml
- configs/mlp_mixer/metafile.yml


@@ -0,0 +1,57 @@
import argparse
from pathlib import Path
import torch
def convert_weights(weight):
"""Weight Converter.
Converts the weights from timm to mmcls
Args:
weight (dict): weight dict from timm
Returns: converted weight dict for mmcls
"""
result = dict()
result['meta'] = dict()
temp = dict()
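    # Substring renaming rules from timm's MLP-Mixer parameter names to the
    # corresponding mmcls MlpMixer backbone names; applied in insertion order.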
mapping = {
'stem': 'patch_embed',
'proj': 'projection',
'mlp_tokens.fc1': 'token_mix.layers.0.0',
'mlp_tokens.fc2': 'token_mix.layers.1',
'mlp_channels.fc1': 'channel_mix.layers.0.0',
'mlp_channels.fc2': 'channel_mix.layers.1',
'norm1': 'ln1',
'norm2': 'ln2',
'norm.': 'ln1.',
'blocks': 'layers'
}
for k, v in weight.items():
for mk, mv in mapping.items():
if mk in k:
k = k.replace(mk, mv)
if k.startswith('head.'):
temp['head.fc.' + k[5:]] = v
else:
temp['backbone.' + k] = v
result['state_dict'] = temp
return result
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert model keys')
parser.add_argument('src', help='source timm checkpoint path')
parser.add_argument('dst', help='save path')
args = parser.parse_args()
dst = Path(args.dst)
if dst.suffix != '.pth':
print('The destination path must be a filename ending with ".pth".')
exit()
dst.parent.mkdir(parents=True, exist_ok=True)
original_model = torch.load(args.src, map_location='cpu')
converted_model = convert_weights(original_model)
torch.save(converted_model, args.dst)
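A hedged usage sketch of the converter's output (the script name and file names below are placeholders, not part of this PR): run the script with a downloaded timm checkpoint as `src` and a `.pth` destination as `dst`, then verify the result loads as an mmcls-style checkpoint:

```python
import torch

# e.g. after: python convert_mixer.py jx_mixer_b16_224-76587d61.pth mixer-base-p16.pth
ckpt = torch.load('mixer-base-p16.pth', map_location='cpu')

# The converter always emits a dict with 'meta' and 'state_dict' keys.
assert set(ckpt) == {'meta', 'state_dict'}

# Keys should carry the mmcls prefixes produced by the mapping above.
print([k for k in ckpt['state_dict'] if k.startswith('head.fc.')])
print([k for k in ckpt['state_dict'] if 'token_mix' in k][:3])
```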