[Feature] Support MixMIM Pretrain and Finetuning (#626)

* [Feature] Support MixMIM Pretrain and Finetuning

* [Feature] Fix Lint

* [Feature] Add doc string for MixMIMTransformerPretrain

* [Feature] Add doc string for MixMIMPretrainHead

* [Feature] Add README

* [Feature] Fix Config

* [Feature] Fix Config

* [Feature] Add doc string to mixmim neck and backbone

* [Feature] Replace MixMIMTransformer with import from mmcls

* add an explanation of the lr

* [Feature] Fix lint

* [Feature] Modification after Review

* [Feature] Modification after Review2

* [Feature] Modification after Review3

* [Feature] Fix lint

Co-authored-by: WasedaMagina <33023171+WasedaMagina@users.noreply.github.com>
Wangbo Zhao(黑色枷锁) 2022-12-30 22:11:53 +08:00 committed by GitHub
parent 304e81650a
commit a08faa1e11
21 changed files with 840 additions and 9 deletions


@@ -138,6 +138,7 @@ Supported algorithms:
- [x] [MILAN (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/milan)
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)
More algorithms are in our plan.


@@ -138,6 +138,7 @@ Useful Tools
- [x] [MILAN (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/milan)
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)
More algorithms are in our plan.


@@ -0,0 +1,30 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='RandomResizedCrop',
size=224,
scale=(0.2, 1.0),
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5),
dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]
train_dataloader = dict(
batch_size=128,
num_workers=32,
persistent_workers=True,
pin_memory=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
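
This config assumes the standard ImageNet layout under `data/imagenet/`. A quick sanity check of that layout (illustrative only; adjust `data_root` if yours differs):

```python
# Sanity-check the directory layout the dataset config above expects.
from pathlib import Path

root = Path('data/imagenet')
assert (root / 'meta' / 'train.txt').is_file()  # lines: "<relative_path> <label>"
assert (root / 'train').is_dir()                # images referenced by train.txt
print('ImageNet layout looks OK')
```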


@@ -0,0 +1,24 @@
model = dict(
type='MixMIM',
data_preprocessor=dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='MixMIMTransformerPretrain',
arch='B',
drop_rate=0.0,
drop_path_rate=0.0, # drop_path_rate=0.0 during pretraining
),
neck=dict(
type='MixMIMPretrainDecoder',
num_patches=49,
encoder_stride=32,
embed_dim=1024,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16),
head=dict(
type='MixMIMPretrainHead',
norm_pix=True,
loss=dict(type='PixelReconstructionLoss', criterion='L2')))
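
The pieces above can be assembled through the registry. Below is a minimal sketch that builds the model and runs one pretraining loss step on random data; it is illustrative, not a definitive recipe, and assumes mmselfsup dev-1.x (with mmcls installed), where `register_all_modules` lives in `mmselfsup.utils`:

```python
import torch
from mmselfsup.registry import MODELS
from mmselfsup.utils import register_all_modules  # assumed dev-1.x helper

register_all_modules()  # register mmselfsup (and downstream) modules
model = MODELS.build(
    dict(
        type='MixMIM',
        backbone=dict(type='MixMIMTransformerPretrain', arch='B'),
        neck=dict(
            type='MixMIMPretrainDecoder',
            num_patches=49,
            encoder_stride=32,
            embed_dim=1024,
            decoder_embed_dim=512,
            decoder_depth=8,
            decoder_num_heads=16),
        head=dict(
            type='MixMIMPretrainHead',
            norm_pix=True,
            loss=dict(type='PixelReconstructionLoss', criterion='L2'))))
losses = model.loss([torch.randn(2, 3, 224, 224)], data_samples=[])
print(losses['loss'])  # scalar reconstruction loss
```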


@@ -0,0 +1,73 @@
# MixMIM
> [MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning](https://arxiv.org/abs/2205.13137)
<!-- [ALGORITHM] -->
## Abstract
In this study, we propose Mixed and Masked Image Modeling (MixMIM), a
simple but efficient MIM method that is applicable to various hierarchical Vision
Transformers. Existing MIM methods replace a random subset of input tokens with
a special \[MASK\] symbol and aim at reconstructing original image tokens from
the corrupted image. However, we find that using the \[MASK\] symbol greatly
slows down the training and causes training-finetuning inconsistency, due to the
large masking ratio (e.g., 40% in BEiT). In contrast, we replace the masked tokens
of one image with visible tokens of another image, i.e., creating a mixed image.
We then conduct dual reconstruction to reconstruct the original two images from
the mixed input, which significantly improves efficiency. While MixMIM can
be applied to various architectures, this paper explores a simpler but stronger
hierarchical Transformer, and scales with MixMIM-B, -L, and -H. Empirical
results demonstrate that MixMIM can learn high-quality visual representations
efficiently. Notably, MixMIM-B with 88M parameters achieves 85.1% top-1
accuracy on ImageNet-1K by pretraining for 600 epochs, setting a new record for
neural networks with comparable model sizes (e.g., ViT-B) among MIM methods.
Besides, its transferring performances on the other 6 datasets show MixMIM has
better FLOPs / performance tradeoff than previous MIM methods.
<div align=center>
<img src="https://user-images.githubusercontent.com/56866854/202853730-d26fb3d7-e5e8-487a-aad5-e3d4600cef87.png"/>
</div>
## Models and Benchmarks
Here, we report the results of the model on ImageNet; the details are below:
<table class="docutils">
<thead>
<tr>
<th rowspan="2">Algorithm</th>
<th rowspan="2">Backbone</th>
<th rowspan="2">Epoch</th>
<th rowspan="2">Batch Size</th>
<th colspan="1" align="center">Results (Top-1 %)</th>
<th colspan="2" align="center">Links</th>
</tr>
<tr>
<th>Fine-tuning</th>
<th>Pretrain</th>
<th>Fine-tuning</th>
</tr>
</thead>
<tbody>
<tr>
<td>MixMIM</td>
<td>MixMIM-base</td>
<td>300</td>
<td>2048</td>
<td>84.63</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221204_134711.json'>log</a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
</tr>
</tbody>
</table>
## Citation
```bibtex
@article{MixMIM2022,
author = {Jihao Liu and Xin Huang and Yu Liu and Hongsheng Li},
journal = {arXiv:2205.13137},
title = {MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning},
year = {2022},
}
```


@@ -0,0 +1,133 @@
_base_ = [
'mmcls::_base_/models/mixmim/mixmim_base.py',
'mmcls::_base_/datasets/imagenet_bs64_swin_224.py',
'mmcls::_base_/default_runtime.py'
]
dataset_type = 'ImageNet'
preprocess_cfg = dict(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
)
bgr_mean = preprocess_cfg['mean'][::-1]
bgr_std = preprocess_cfg['std'][::-1]
file_client_args = dict(backend='disk')
data_root = 'data/imagenet/'
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackClsInputs'),
]
train_dataloader = dict(
batch_size=128,
num_workers=16,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
]
val_dataloader = dict(
batch_size=64,
num_workers=8,
pin_memory=True,
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_dataloader = val_dataloader
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='AdamW',
lr=5e-4 *
(8 * 128 / 256), # total_lr = base_lr*num_gpus*base_bs/256 = 2e-3
model_type='mixmim',
layer_decay_rate=0.7,
betas=(0.9, 0.999),
weight_decay=0.05),
constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
custom_keys={
'.ln': dict(decay_mult=0.0), # do not decay on ln and bias
'.bias': dict(decay_mult=0.0)
}))
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-6,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=95,
eta_min=1e-6,
by_epoch=True,
begin=5,
end=100,
convert_to_iter_based=True)
]
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=10)
val_cfg = dict()
test_cfg = dict()
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
logger=dict(type='LoggerHook', interval=100))
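
Two knobs above interact: the linearly scaled base learning rate and the layer-wise decay. A quick check of the numbers in the comments; the exact per-layer bookkeeping lives in `LearningRateDecayOptimWrapperConstructor`, so the exponent below is the usual convention and an assumption here:

```python
base_lr = 5e-4
total_lr = base_lr * (8 * 128) / 256      # 8 GPUs x 128 images/GPU -> 2e-3
print(total_lr)                           # 0.002

layer_decay, num_layers = 0.7, 25         # sum([2, 2, 18, 2]) + 1 for MixMIM-B
# A parameter assigned to layer i is typically scaled by
# layer_decay ** (num_layers - 1 - i), so the earliest blocks move slowest:
print(total_lr * layer_decay ** (num_layers - 1))   # ~3.8e-07 for layer 0
```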


@@ -0,0 +1,35 @@
Collections:
- Name: MixMIM
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- AdamW
Training Resources: 16x A100-80G GPUs
Architecture:
- MixMIM ViT
Paper:
URL: https://arxiv.org/abs/2205.13137
Title: "MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning"
README: configs/selfsup/mixmim/README.md
Models:
- Name: mixmim-base-p16_16xb128-coslr-300e_in1k
In Collection: MixMIM
Metadata:
Epochs: 300
Batch Size: 2048
Results: null
Config: configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth
Downstream:
- Type: Image Classification
Metadata:
Epochs: 100
Batch Size: 1024
Results:
- Task: Fine-tuning
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 84.63
Config: configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth


@@ -0,0 +1,46 @@
_base_ = [
'../_base_/models/mixmim.py',
'../_base_/datasets/imagenet_mixmim.py',
'../_base_/schedules/adamw_coslr-200e_in1k.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optimizer = dict(
type='AdamW',
lr=1.5e-4 *
(2 * 8 * 128 / 256), # total_lr = base_lr*num_gpus*base_bs/256 = 1.2e-3
betas=(0.9, 0.95),
    weight_decay=0.05)  # 2 nodes * 8 GPUs * 128 samples per GPU
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
paramwise_cfg=dict(custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0)
}))
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=260,
by_epoch=True,
begin=40,
end=300,
convert_to_iter_based=True)
]
train_cfg = dict(max_epochs=300)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=100),
checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=1))
# randomness
randomness = dict(seed=0, diff_rank_seed=True)
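
For reference, the warmup-plus-cosine schedule above can be sketched standalone (not using mmengine; `eta_min` is unset, so 0 is assumed, and the values are illustrative rather than exact mmengine output):

```python
import math

base_lr = 1.5e-4 * (2 * 8 * 128 / 256)  # 1.2e-3, as in the comment above

def lr_at_epoch(epoch: float) -> float:
    if epoch < 40:  # LinearLR: start_factor=1e-4 over the first 40 epochs
        factor = 1e-4 + (1.0 - 1e-4) * epoch / 40
        return base_lr * factor
    # CosineAnnealingLR: T_max=260, from epoch 40 to 300, eta_min assumed 0
    return 0.5 * base_lr * (1 + math.cos(math.pi * (epoch - 40) / 260))

for e in (0, 20, 40, 170, 300):
    print(e, f'{lr_at_epoch(e):.2e}')   # peaks at 1.2e-3, decays to 0
```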


@@ -430,5 +430,16 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. T
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/eva/classification/vit-base-p16_linear-8xb2048-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221226-ef51bf09.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221222_134137.json'>log</a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/eva/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221226-f61cf992.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221221_212618.json'>log</a></td>
</tr>
<tr>
<td>MixMIM</td>
<td>MixMIM-Base</td>
<td>300</td>
<td>2048</td>
<td>/</td>
<td>84.6</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221204_134711.json'>log</a></td>
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
</tr>
</tbody>
</table>


@@ -430,5 +430,16 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. We provide
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/eva/classification/vit-base-p16_linear-8xb2048-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221226-ef51bf09.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221222_134137.json'>log</a></td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/eva/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221226-f61cf992.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221221_212618.json'>log</a></td>
</tr>
<tr>
<td>MixMIM</td>
<td>MixMIM-Base</td>
<td>300</td>
<td>2048</td>
<td>/</td>
<td>84.6</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221204_134711.json'>log</a></td>
<td>/</td>
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
</tr>
</tbody>
</table>


@@ -60,6 +60,40 @@ def get_layer_id_for_swin(var_name: str, max_layer_id: int,
return max_layer_id - 1
def get_layer_id_for_mixmim(var_name: str, max_layer_id: int,
depths: List[int]) -> int:
"""Get the layer id to set the different learning rates for MixMIM.
The layer is from 1 to max_layer_id (e.g. 25)
Args:
var_name (str): The key of the model.
num_max_layer (int): Maximum number of backbone layers.
depths (List[int]): Depths for each stage.
Returns:
int: Returns the layer id of the key.
"""
if 'patch_embed' in var_name:
return -1
elif 'absolute_pos_embed' in var_name:
return -1
elif 'pos_embed' in var_name:
return -1
elif var_name.startswith('backbone.layers'):
layer_id = int(var_name.split('.')[2])
block_id = var_name.split('.')[4]
if block_id == 'downsample' or \
block_id == 'reduction' or \
block_id == 'norm':
return sum(depths[:layer_id + 1]) - 1
layer_id = sum(depths[:layer_id]) + int(block_id) + 1
return layer_id - 1
else:
return max_layer_id - 2
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
"""Different learning rates are set for different layers of backbone.
@@ -113,13 +147,16 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
        # currently, we only support layer-wise learning rate decay for vit,
        # swin and mixmim.
        assert model_type in ['vit', 'swin',
                              'mixmim'], f'Currently, we do not support \
            layer-wise learning rate decay for {model_type}'
if model_type == 'vit':
num_layers = len(module.backbone.layers) + 2
elif model_type == 'swin':
num_layers = sum(module.backbone.depths) + 2
elif model_type == 'mixmim':
num_layers = sum(module.backbone.depths) + 1
        # if layer_decay_rate is not provided, do not decay
decay_rate = optimizer_cfg.pop('layer_decay_rate', 1.0)
@@ -146,6 +183,9 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
elif model_type == 'swin':
layer_id = get_layer_id_for_swin(name, num_layers,
module.backbone.depths)
elif model_type == 'mixmim':
layer_id = get_layer_id_for_mixmim(name, num_layers,
module.backbone.depths)
group_name = f'layer_{layer_id}_{group_name}'
if group_name not in parameter_groups:
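
To make the id assignment concrete, here are a few worked calls for MixMIM-Base (`depths = [2, 2, 18, 2]`, so `num_layers = sum(depths) + 1 = 25`). The import path is an assumption based on where this constructor lives in dev-1.x:

```python
from mmselfsup.engine.optimizers.layer_decay_optim_wrapper_constructor import \
    get_layer_id_for_mixmim  # assumed module path

depths = [2, 2, 18, 2]
names = [
    'backbone.patch_embed.projection.weight',      # -> -1
    'backbone.layers.0.blocks.0.norm1.weight',     # -> 0  (first block)
    'backbone.layers.2.blocks.5.attn.qkv.weight',  # -> 9  (2 + 2 + 5 + 1 - 1)
    'backbone.layers.1.downsample.norm.weight',    # -> 3  (sum(depths[:2]) - 1)
    'backbone.norm.weight',                        # -> 23 (max_layer_id - 2)
]
for name in names:
    print(name, get_layer_id_for_mixmim(name, 25, depths))
```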


@@ -10,6 +10,7 @@ from .eva import EVA
from .mae import MAE
from .maskfeat import MaskFeat
from .milan import MILAN
from .mixmim import MixMIM
from .moco import MoCo
from .mocov3 import MoCoV3
from .npid import NPID
@@ -24,5 +25,6 @@ from .swav import SwAV
__all__ = [
'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',
'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam',
    'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA',
    'MixMIM'
]


@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
from torch import nn
from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from .base import BaseModel
@MODELS.register_module()
class MixMIM(BaseModel):
"""MiXMIM.
Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient
Visual Representation Learning. <https://arxiv.org/abs/2205.13137>`_.
"""
def __init__(self,
backbone: dict,
neck: Optional[dict] = None,
head: Optional[dict] = None,
pretrained: Optional[str] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
        head.update(dict(patch_size=neck['encoder_stride']))
super().__init__(
backbone=backbone,
neck=neck,
head=head,
pretrained=pretrained,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)
def loss(self, inputs: List[torch.Tensor],
data_samples: List[SelfSupDataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[SelfSupDataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
latent, mask = self.backbone(inputs[0])
x_rec = self.neck(latent, mask)
loss = self.head(x_rec, inputs[0], mask)
losses = dict(loss=loss)
return losses
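
A shapes-only walkthrough of this loss flow with stand-in tensors (illustrative; the real modules are the backbone, neck and head added below): the backbone mixes two images token-wise, the neck reconstructs both, and the head un-mixes the two streams.

```python
import torch

B, L = 4, 49                                   # (224 / 32) ** 2 = 49 patches
mask = (torch.rand(1, L, 1) > 0.5).float()     # shared binary mask

tokens = torch.randn(B, L, 1024)               # stand-in embedded patches
mixed = tokens * (1 - mask) + tokens.flip(0) * mask   # backbone's mix step

x_rec = torch.randn(2 * B, L, 32 ** 2 * 3)     # stand-in neck output (2B!)
x1, x2 = x_rec[:B], x_rec[B:]
unmixed = x1 * mask + x2.flip(0) * (1 - mask)  # head's un-mix step
print(mixed.shape, x_rec.shape, unmixed.shape)
```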


@@ -4,6 +4,7 @@ from .cae_vit import CAEViT
from .mae_vit import MAEViT
from .maskfeat_vit import MaskFeatViT
from .milan_vit import MILANViT
from .mixmim_backbone import MixMIMTransformerPretrain
from .mocov3_vit import MoCoV3ViT
from .resnet import ResNet, ResNetSobel, ResNetV1d
from .resnext import ResNeXt
@@ -11,5 +12,6 @@ from .simmim_swin import SimMIMSwinTransformer
__all__ = [
'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT',
    'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT', 'MILANViT',
    'MixMIMTransformerPretrain'
]


@@ -0,0 +1,200 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional, Union
import torch
from mmcls.models.backbones import MixMIMTransformer
from torch import nn
from torch.nn import functional as F
from mmselfsup.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
@MODELS.register_module()
class MixMIMTransformerPretrain(MixMIMTransformer):
"""MixMIM backbone during pretraining.

    A PyTorch implementation of: `MixMIM: Mixed and Masked Image
    Modeling for Efficient Visual Representation Learning
    <https://arxiv.org/abs/2205.13137>`_.

Args:
arch (str | dict): MixMIM architecture. If use string,
choose from 'base','large' and 'huge'.
If use dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
Defaults to 'base'.
mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (list): The height and width of the window for each
            stage. Defaults to [14, 14, 14, 7].
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
attn_drop_rate (float): Attention drop rate. Defaults to 0.
        use_checkpoint (bool): Whether to use checkpointing to reduce GPU
            memory cost. Defaults to False.
        range_mask_ratio (float): The range of variation of the mask ratio.
            Defaults to 0.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'base',
mlp_ratio: float = 4,
img_size: int = 224,
patch_size: int = 4,
in_channels: int = 3,
window_size: List = [14, 14, 14, 7],
qkv_bias: bool = True,
patch_cfg: dict = dict(),
norm_cfg: dict = dict(type='LN'),
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
attn_drop_rate: float = 0.0,
use_checkpoint: bool = False,
range_mask_ratio: float = 0.0,
init_cfg: Optional[dict] = None) -> None:
super().__init__(
arch=arch,
mlp_ratio=mlp_ratio,
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
window_size=window_size,
qkv_bias=qkv_bias,
patch_cfg=patch_cfg,
norm_cfg=norm_cfg,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
attn_drop_rate=attn_drop_rate,
use_checkpoint=use_checkpoint,
init_cfg=init_cfg)
self.range_mask_ratio = range_mask_ratio
def init_weights(self):
"""Initialize position embedding, patch embedding."""
super(MixMIMTransformer, self).init_weights()
pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.absolute_pos_embed.shape[-1],
cls_token=False)
self.absolute_pos_embed.data.copy_(pos_embed.float())
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def random_masking(self, x: torch.Tensor, mask_ratio: float = 0.5):
"""Generate the mask for MixMIM Pretraining.
Args:
x (torch.Tensor): Image with data augmentation applied, which is
of shape B x L x C.
mask_ratio (float): The mask ratio of total patches.
Defaults to 0.5.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- mask_s1 (torch.Tensor): mask with stride of
self.encoder_stride // 8.
- mask_s2 (torch.Tensor): mask with stride of
self.encoder_stride // 4.
- mask_s3 (torch.Tensor): mask with stride of
self.encoder_stride // 2.
- mask (torch.Tensor): mask with stride of
self.encoder_stride.
"""
B, C, H, W = x.shape
out_H = H // self.encoder_stride
out_W = W // self.encoder_stride
s3_H, s3_W = out_H * 2, out_W * 2
s2_H, s2_W = out_H * 4, out_W * 4
s1_H, s1_W = out_H * 8, out_W * 8
seq_l = out_H * out_W
# use a shared mask for a batch images
mask = torch.zeros([1, 1, seq_l], device=x.device)
mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1]
# ascend: small is keep, large is removed
mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
mask.scatter_(2, mask_idx, 1)
mask = mask.reshape(1, 1, out_H, out_W)
mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')
mask = mask.reshape(1, out_H * out_W, 1).contiguous()
mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()
return mask_s1, mask_s2, mask_s3, mask
    def forward(self, x: torch.Tensor, mask_ratio: float = 0.5):
        """Generate features for masked images.

        This function generates the mask, randomly masks some patches, and
        gets the hidden features of visible patches.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask_ratio (float): The mask ratio of total patches.
                Defaults to 0.5.

Returns:
Tuple[torch.Tensor, torch.Tensor]:
- x (torch.Tensor): hidden features, which is of shape
B x L x C.
- mask_s4 (torch.Tensor): the mask tensor for the last layer.
"""
mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(x, mask_ratio)
x, _ = self.patch_embed(x)
x = x * (1. - mask_s1) + x.flip(0) * mask_s1
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
for idx, layer in enumerate(self.layers):
if idx == 0:
x = layer(x, attn_mask=mask_s1)
elif idx == 1:
x = layer(x, attn_mask=mask_s2)
elif idx == 2:
x = layer(x, attn_mask=mask_s3)
elif idx == 3:
x = layer(x, attn_mask=mask_s4)
x = self.norm(x)
return x, mask_s4
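
The mask pyramid above can be reproduced standalone. For a 224x224 input with `encoder_stride=32`, the shared 7x7 mask is upsampled to the resolutions of the four stages (strides 4, 8, 16, 32); a minimal sketch mirroring `random_masking`:

```python
import torch
import torch.nn.functional as F

out_h = out_w = 224 // 32                 # 7x7 mask at the final stage
seq_l, mask_ratio = out_h * out_w, 0.5

mask = torch.zeros(1, 1, seq_l)
noise = torch.rand(1, 1, seq_l)
idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
mask.scatter_(2, idx, 1)                  # 1 = masked, 0 = visible
mask = mask.reshape(1, 1, out_h, out_w)

mask_s1 = F.interpolate(mask, size=(56, 56), mode='nearest')  # stage 1
mask_s2 = F.interpolate(mask, size=(28, 28), mode='nearest')  # stage 2
mask_s3 = F.interpolate(mask, size=(14, 14), mode='nearest')  # stage 3
print(mask_s1.shape, mask_s2.shape, mask_s3.shape, mask.shape)
```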


@@ -8,6 +8,7 @@ from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead
from .mae_head import MAEPretrainHead
from .maskfeat_head import MaskFeatPretrainHead
from .milan_head import MILANPretrainHead
from .mixmim_head import MixMIMPretrainHead
from .mocov3_head import MoCoV3Head
from .multi_cls_head import MultiClsHead
from .simmim_head import SimMIMHead
@@ -17,5 +18,5 @@ __all__ = [
'BEiTV1Head', 'BEiTV2Head', 'ContrastiveHead', 'ClsHead',
'LatentPredictHead', 'LatentCrossCorrelationHead', 'MultiClsHead',
'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead',
    'MaskFeatPretrainHead', 'MILANPretrainHead', 'MixMIMPretrainHead'
]


@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmselfsup.registry import MODELS
from .mae_head import MAEPretrainHead
@MODELS.register_module()
class MixMIMPretrainHead(MAEPretrainHead):
"""MixMIM pretrain head.
Args:
loss (dict): Config of loss.
        norm_pix (bool): Whether or not to normalize the target.
            Defaults to False.
patch_size (int): Patch size. Defaults to 16.
"""
def __init__(self,
loss: dict,
norm_pix: bool = False,
patch_size: int = 16) -> None:
super().__init__(loss=loss, norm_pix=norm_pix, patch_size=patch_size)
def forward(self, x_rec: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Forward function of MixMIM head.
Args:
            x_rec (torch.Tensor): The reconstructed image.
target (torch.Tensor): The target image.
mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
target = self.construct_target(target)
B, L, C = x_rec.shape
# unmix tokens
x1_rec = x_rec[:B // 2]
x2_rec = x_rec[B // 2:]
unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask)
loss_rec = self.loss(unmix_x_rec, target)
return loss_rec


@@ -38,8 +38,10 @@ class PixelReconstructionLoss(BaseModule):
self.channel = channel if channel is not None else 1
    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward function to compute the reconstrction loss.
Args:
@@ -57,6 +59,9 @@ class PixelReconstructionLoss(BaseModule):
if len(loss.shape) == 3:
loss = loss.mean(dim=-1)
        if mask is None:
            loss = loss.mean()
        else:
            loss = (loss * mask).sum() / mask.sum() / self.channel
return loss
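
With the mask now optional, the same loss serves both the SimMIM/MAE-style masked mean and MixMIM's plain mean over the un-mixed tokens. A hedged usage sketch, assuming the mmselfsup registry has been set up (e.g. via `register_all_modules`):

```python
import torch
from mmselfsup.registry import MODELS

loss_fn = MODELS.build(dict(type='PixelReconstructionLoss', criterion='L2'))
pred = torch.randn(2, 49, 3072)
target = torch.randn(2, 49, 3072)
mask = (torch.rand(2, 49) > 0.5).float()

print(loss_fn(pred, target, mask))  # masked mean, the original behaviour
print(loss_fn(pred, target))        # mask=None: plain mean, the MixMIM path
```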


@@ -6,6 +6,7 @@ from .densecl_neck import DenseCLNeck
from .linear_neck import LinearNeck
from .mae_neck import ClsBatchNormNeck, MAEPretrainDecoder
from .milan_neck import MILANPretrainDecoder
from .mixmim_neck import MixMIMPretrainDecoder
from .mocov2_neck import MoCoV2Neck
from .nonlinear_neck import NonLinearNeck
from .odc_neck import ODCNeck
@@ -16,6 +17,6 @@ from .swav_neck import SwAVNeck
__all__ = [
'AvgPool2dNeck', 'BEiTV2Neck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck',
'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck',
    'MAEPretrainDecoder', 'SimMIMNeck', 'CAENeck', 'MixMIMPretrainDecoder',
    'ClsBatchNormNeck', 'MILANPretrainDecoder'
]


@@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmselfsup.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
from .mae_neck import MAEPretrainDecoder
@MODELS.register_module()
class MixMIMPretrainDecoder(MAEPretrainDecoder):
"""Decoder for MixMIM Pretraining.
Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. # noqa
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
in_chans (int): The channel of input image. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
encoder_stride (int): The output stride of MixMIM backbone. Defaults
to 32.
decoder_embed_dim (int): Decoder's embedding dimension.
Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 8.
decoder_num_heads (int): Number of attention heads of decoder.
Defaults to 16.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
"""
def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 1024,
encoder_stride: int = 32,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
num_patches=num_patches,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, decoder_embed_dim),
requires_grad=False)
self.decoder_pred = nn.Linear(decoder_embed_dim, encoder_stride**2 * 3)
def init_weights(self) -> None:
"""Initialize position embedding and mask token of MixMIM decoder."""
super(MAEPretrainDecoder, self).init_weights()
decoder_pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.decoder_pos_embed.shape[-1],
cls_token=False)
self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
torch.nn.init.normal_(self.mask_token, std=.02)
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
x (torch.Tensor): The input features, which is of shape (N, L, C).
            mask (torch.Tensor): The tensor to indicate which tokens are
                masked.

        Returns:
            torch.Tensor: The reconstructed features, which is of shape
            (2 * N, L, encoder_stride ** 2 * 3); the batch is doubled for
            the dual reconstruction.
"""
x = self.decoder_embed(x)
B, L, C = x.shape
mask_tokens = self.mask_token.expand(B, L, -1)
x1 = x * (1 - mask) + mask_tokens * mask
x2 = x * mask + mask_tokens * (1 - mask)
x = torch.cat([x1, x2], dim=0)
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for idx, blk in enumerate(self.decoder_blocks):
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
return x
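
The key step is the dual fill before the Transformer blocks: each stream keeps its visible tokens and takes the shared mask token elsewhere, doubling the batch. A tiny stand-in sketch:

```python
import torch

N, L, C = 2, 49, 512                       # after self.decoder_embed
x = torch.randn(N, L, C)
mask = (torch.rand(1, L, 1) > 0.5).float()
mask_token = torch.zeros(1, 1, C)          # stand-in for self.mask_token

mask_tokens = mask_token.expand(N, L, -1)
x1 = x * (1 - mask) + mask_tokens * mask   # stream reconstructing image 1
x2 = x * mask + mask_tokens * (1 - mask)   # stream reconstructing image 2
x = torch.cat([x1, x2], dim=0)
print(x.shape)                             # (2 * N, L, C)
```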


@@ -21,3 +21,4 @@ Import:
- configs/selfsup/milan/metafile.yaml
- configs/selfsup/beitv2/metafile.yml
- configs/selfsup/eva/metafile.yaml
- configs/selfsup/mixmim/metafile.yaml