[Enhance] Add metafile, readme and converted models for MLP-Mixer (#539)
* add pth converter * minor update on config files, add metafile and readme * add missing readme and minor fixes * minor fixes * Update config names and checkpoint download link * Update model_zoo.md Co-authored-by: mzr1996 <mzr1996@163.com>pull/555/head
parent
f3fbc8b90b
commit
fc8adbc149
|
@ -0,0 +1,48 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
|
||||
# change according to https://github.com/rwightman/pytorch-image-models/blob
|
||||
# /master/timm/models/mlp_mixer.py
|
||||
img_norm_cfg = dict(
|
||||
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
|
||||
|
||||
# training is not supported for now
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='RandomResizedCrop', size=224, backend='cv2'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize', size=(256, -1), backend='cv2', interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=8,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
|
@ -3,9 +3,9 @@ model = dict(
|
|||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='MlpMixer',
|
||||
arch='b',
|
||||
arch='l',
|
||||
img_size=224,
|
||||
patch_size=32,
|
||||
patch_size=16,
|
||||
drop_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(
|
||||
|
@ -18,7 +18,7 @@ model = dict(
|
|||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=768,
|
||||
in_channels=1024,
|
||||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
|
||||
topk=(1, 5),
|
||||
),
|
|
@ -0,0 +1,37 @@
|
|||
# MLP-Mixer: An all-MLP Architecture for Vision
|
||||
<!-- {Mlp-Mixer} -->
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
<!-- [ABSTRACT] -->
|
||||
Convolutional Neural Networks (CNNs) are the go-to model for computer vision. Recently, attention-based networks, such as the Vision Transformer, have also become popular. In this paper we show that while convolutions and attention are both sufficient for good performance, neither of them are necessary. We present MLP-Mixer, an architecture based exclusively on multi-layer perceptrons (MLPs). MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. "mixing" the per-location features), and one with MLPs applied across patches (i.e. "mixing" spatial information). When trained on large datasets, or with modern regularization schemes, MLP-Mixer attains competitive scores on image classification benchmarks, with pre-training and inference cost comparable to state-of-the-art models. We hope that these results spark further research beyond the realms of well established CNNs and Transformers.
|
||||
|
||||
<!-- [IMAGE] -->
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/26739999/143178327-7118b48a-5f5f-4844-a614-a571917384ca.png" width="90%"/>
|
||||
</div>
|
||||
|
||||
## Citation
|
||||
```latex
|
||||
@misc{tolstikhin2021mlpmixer,
|
||||
title={MLP-Mixer: An all-MLP Architecture for Vision},
|
||||
author={Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Andreas Steiner and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
|
||||
year={2021},
|
||||
eprint={2105.01601},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
||||
|
||||
## Pretrain model
|
||||
|
||||
The pre-trained modles are converted from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py).
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|
||||
|:--------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
|
||||
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth)|
|
||||
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth)|
|
||||
|
||||
*Models with \* are converted from other repos.*
|
|
@ -0,0 +1,50 @@
|
|||
Collections:
|
||||
- Name: MLP-Mixer
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
Architecture:
|
||||
- MLP
|
||||
- Layer Normalization
|
||||
- Dropout
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2105.01601
|
||||
Title: "MLP-Mixer: An all-MLP Architecture for Vision"
|
||||
README: configs/mlp_mixer/README.md
|
||||
# Code:
|
||||
# URL: # todo
|
||||
# Version: # todo
|
||||
|
||||
Models:
|
||||
- Name: mlp-mixer-base-p16_3rdparty_64xb64_in1k
|
||||
In Collection: MLP-Mixer
|
||||
Config: configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py
|
||||
Metadata:
|
||||
FLOPs: 12610000000 # 12.61 G
|
||||
Parameters: 59880000 # 59.88 M
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 76.68
|
||||
Top 5 Accuracy: 92.25
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth
|
||||
Converted From:
|
||||
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth
|
||||
Code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py#L70
|
||||
|
||||
- Name: mlp-mixer-large-p16_3rdparty_64xb64_in1k
|
||||
In Collection: MLP-Mixer
|
||||
Config: configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py
|
||||
Metadata:
|
||||
FLOPs: 44570000000 # 44.57 G
|
||||
Parameters: 208200000 # 208.2 M
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 72.34
|
||||
Top 5 Accuracy: 88.02
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth
|
||||
Converted From:
|
||||
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth
|
||||
Code: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/mlp_mixer.py#L73
|
|
@ -1,6 +1,6 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mlp_mixer_base_patch16.py',
|
||||
'../_base_/datasets/imagenet_bs64_pil_resize.py',
|
||||
'../_base_/datasets/imagenet_bs64_mixer_224.py',
|
||||
'../_base_/schedules/imagenet_bs4096_AdamW.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
|
@ -0,0 +1,6 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mlp_mixer_large_patch16.py',
|
||||
'../_base_/datasets/imagenet_bs64_mixer_224.py',
|
||||
'../_base_/schedules/imagenet_bs4096_AdamW.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
|
@ -1,6 +0,0 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mlp_mixer_base_patch32.py',
|
||||
'../_base_/datasets/imagenet_bs64_pil_resize.py',
|
||||
'../_base_/schedules/imagenet_bs4096_AdamW.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
|
@ -61,6 +61,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) | [log]()|
|
||||
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) | [log]()|
|
||||
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) | [log]()|
|
||||
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
|
||||
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
|
||||
|
||||
|
||||
Models with * are converted from other repos, others are trained by ourselves.
|
||||
|
||||
|
|
|
@ -12,7 +12,10 @@ from .base_backbone import BaseBackbone
|
|||
|
||||
|
||||
class MixerBlock(BaseModule):
|
||||
"""Implements mixer block in MLP Mixer.
|
||||
"""Mlp-Mixer basic block.
|
||||
|
||||
Basic module of `MLP-Mixer: An all-MLP Architecture for Vision
|
||||
<https://arxiv.org/pdf/2105.01601.pdf>`_
|
||||
|
||||
Args:
|
||||
num_tokens (int): The number of patched tokens
|
||||
|
@ -96,13 +99,14 @@ class MixerBlock(BaseModule):
|
|||
|
||||
@BACKBONES.register_module()
|
||||
class MlpMixer(BaseBackbone):
|
||||
"""Mlp Mixer.
|
||||
"""Mlp-Mixer backbone.
|
||||
|
||||
Pytorch implementation of `MLP-Mixer: An all-MLP Architecture for Vision
|
||||
<https://arxiv.org/pdf/2105.01601.pdf>`_
|
||||
|
||||
A PyTorch implement of : `MLP-Mixer: An all-MLP Architecture for Vision` -
|
||||
https://arxiv.org/abs/2105.01601
|
||||
Args:
|
||||
arch (str | dict): MLP Mixer architecture
|
||||
Default: 'b'.
|
||||
Defaults to 'b'.
|
||||
img_size (int | tuple): Input image size.
|
||||
patch_size (int | tuple): The patch size.
|
||||
out_indices (Sequence | int): Output from which layer.
|
||||
|
|
|
@ -12,3 +12,4 @@ Import:
|
|||
- configs/tnt/metafile.yml
|
||||
- configs/vision_transformer/metafile.yml
|
||||
- configs/t2t_vit/metafile.yml
|
||||
- configs/mlp_mixer/metafile.yml
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def convert_weights(weight):
|
||||
"""Weight Converter.
|
||||
|
||||
Converts the weights from timm to mmcls
|
||||
|
||||
Args:
|
||||
weight (dict): weight dict from timm
|
||||
|
||||
Returns: converted weight dict for mmcls
|
||||
"""
|
||||
result = dict()
|
||||
result['meta'] = dict()
|
||||
temp = dict()
|
||||
mapping = {
|
||||
'stem': 'patch_embed',
|
||||
'proj': 'projection',
|
||||
'mlp_tokens.fc1': 'token_mix.layers.0.0',
|
||||
'mlp_tokens.fc2': 'token_mix.layers.1',
|
||||
'mlp_channels.fc1': 'channel_mix.layers.0.0',
|
||||
'mlp_channels.fc2': 'channel_mix.layers.1',
|
||||
'norm1': 'ln1',
|
||||
'norm2': 'ln2',
|
||||
'norm.': 'ln1.',
|
||||
'blocks': 'layers'
|
||||
}
|
||||
for k, v in weight.items():
|
||||
for mk, mv in mapping.items():
|
||||
if mk in k:
|
||||
k = k.replace(mk, mv)
|
||||
if k.startswith('head.'):
|
||||
temp['head.fc.' + k[5:]] = v
|
||||
else:
|
||||
temp['backbone.' + k] = v
|
||||
result['state_dict'] = temp
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Convert model keys')
|
||||
parser.add_argument('src', help='src detectron model path')
|
||||
parser.add_argument('dst', help='save path')
|
||||
args = parser.parse_args()
|
||||
dst = Path(args.dst)
|
||||
if dst.suffix != '.pth':
|
||||
print('The path should contain the name of the pth format file.')
|
||||
exit()
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
original_model = torch.load(args.src, map_location='cpu')
|
||||
converted_model = convert_weights(original_model)
|
||||
torch.save(converted_model, args.dst)
|
Loading…
Reference in New Issue