[Feature] Support MixMIM Pretrain and Finetuning (#626)
* [Feature] Support MixMIM Pretrain and Finetuning
* [Feature] Fix Lint
* [Feature] Add doc string for MixMIMTransformerPretrain
* [Feature] Add doc string for MixMIMPretrainHead
* [Feature] Add README
* [Feature] Fix Config
* [Feature] Add doc string to mixmim neck and backbone
* [Feature] Replace MixMIMTransformer with import from mmcls
* add an explanation of the lr
* [Feature] Modification after Review
* [Feature] Modification after Review2
* [Feature] Modification after Review3
* [Feature] Fix lint

Co-authored-by: WasedaMagina <33023171+WasedaMagina@users.noreply.github.com>
parent 304e81650a · commit a08faa1e11
@@ -138,6 +138,7 @@ Supported algorithms:

- [x] [MILAN (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/milan)
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)

More algorithms are in our plan.
@@ -138,6 +138,7 @@ Useful Tools

- [x] [MILAN (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/milan)
- [x] [BEiT v2 (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/beitv2)
- [x] [EVA (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/eva)
- [x] [MixMIM (arXiv'2022)](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim)

More algorithm implementations are in our plan.
@@ -0,0 +1,30 @@
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_root = 'data/imagenet/'
file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='RandomResizedCrop',
        size=224,
        scale=(0.2, 1.0),
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackSelfSupInputs', meta_keys=['img_path'])
]

train_dataloader = dict(
    batch_size=128,
    num_workers=32,
    persistent_workers=True,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline))
@@ -0,0 +1,24 @@
model = dict(
    type='MixMIM',
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='MixMIMTransformerPretrain',
        arch='B',
        drop_rate=0.0,
        drop_path_rate=0.0,  # drop_path_rate=0.0 during pretraining
    ),
    neck=dict(
        type='MixMIMPretrainDecoder',
        num_patches=49,
        encoder_stride=32,
        embed_dim=1024,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16),
    head=dict(
        type='MixMIMPretrainHead',
        norm_pix=True,
        loss=dict(type='PixelReconstructionLoss', criterion='L2')))
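These three modules compose exactly as `MixMIM.loss()` (shown later in this PR) chains them: backbone → neck → head. A minimal sketch of that flow, assuming mmselfsup 1.x with this PR applied and the classes exported from `mmselfsup.models`; the dummy batch and printed shapes are illustrative:

```python
# Minimal sketch (not part of the PR) of the pretraining forward flow.
import torch
from mmselfsup.models import (MixMIMPretrainDecoder, MixMIMPretrainHead,
                              MixMIMTransformerPretrain)

backbone = MixMIMTransformerPretrain(arch='B')
neck = MixMIMPretrainDecoder(
    num_patches=49, encoder_stride=32, embed_dim=1024,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16)
head = MixMIMPretrainHead(
    loss=dict(type='PixelReconstructionLoss', criterion='L2'),
    norm_pix=True,
    patch_size=32)  # MixMIM.__init__ copies this from the neck's encoder_stride
backbone.init_weights()
neck.init_weights()

imgs = torch.randn(2, 3, 224, 224)  # each image is mixed with its batch mirror
latent, mask = backbone(imgs)       # latent: (2, 49, 1024), mask: (1, 49, 1)
x_rec = neck(latent, mask)          # (4, 49, 3072): dual reconstruction branches
loss = head(x_rec, imgs, mask)      # scalar L2 reconstruction loss
print(latent.shape, x_rec.shape, loss.shape)
```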
@@ -0,0 +1,73 @@
# MixMIM

> [MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning](https://arxiv.org/abs/2205.13137)

<!-- [ALGORITHM] -->

## Abstract

In this study, we propose Mixed and Masked Image Modeling (MixMIM), a simple but efficient MIM method that is applicable to various hierarchical Vision Transformers. Existing MIM methods replace a random subset of input tokens with a special \[MASK\] symbol and aim at reconstructing original image tokens from the corrupted image. However, we find that using the \[MASK\] symbol greatly slows down the training and causes training-finetuning inconsistency, due to the large masking ratio (e.g., 40% in BEiT). In contrast, we replace the masked tokens of one image with visible tokens of another image, i.e., creating a mixed image. We then conduct dual reconstruction to reconstruct the original two images from the mixed input, which significantly improves efficiency. While MixMIM can be applied to various architectures, this paper explores a simpler but stronger hierarchical Transformer, and scales with MixMIM-B, -L, and -H. Empirical results demonstrate that MixMIM can learn high-quality visual representations efficiently. Notably, MixMIM-B with 88M parameters achieves 85.1% top-1 accuracy on ImageNet-1K by pretraining for 600 epochs, setting a new record for neural networks with comparable model sizes (e.g., ViT-B) among MIM methods. Besides, its transferring performances on the other 6 datasets show MixMIM has better FLOPs / performance tradeoff than previous MIM methods.

<div align=center>
<img src="https://user-images.githubusercontent.com/56866854/202853730-d26fb3d7-e5e8-487a-aad5-e3d4600cef87.png"/>
</div>
</div>
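The core idea above, replacing the masked tokens of one image with the visible tokens of another and then reconstructing both, fits in a few lines of PyTorch. This is an illustrative toy on flattened patch embeddings, not the PR's implementation (see `mixmim_backbone.py` below for the real one); all shapes are assumptions:

```python
# Toy sketch of MixMIM's token mixing (illustrative, not the PR's code).
import torch

B, L, C = 4, 49, 768            # batch, tokens per image, embedding dim
tokens = torch.randn(B, L, C)   # patch embeddings of a batch of images

# one shared binary mask per batch: 1 = masked position, 0 = visible
mask = (torch.rand(1, L, 1) < 0.5).float()

# masked positions of image i are filled with the visible tokens of the
# "partner" image B-1-i (the batch flipped along dim 0)
mixed = tokens * (1. - mask) + tokens.flip(0) * mask

# dual reconstruction then recovers both originals from `mixed`:
# positions with mask==0 supervise image i, mask==1 supervise image B-1-i
print(mixed.shape)  # torch.Size([4, 49, 768])
```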

## Models and Benchmarks

Here, we report the results of the model on ImageNet; the details are below:

<table class="docutils">
<thead>
  <tr>
      <th rowspan="2">Algorithm</th>
      <th rowspan="2">Backbone</th>
      <th rowspan="2">Epoch</th>
      <th rowspan="2">Batch Size</th>
      <th colspan="1" align="center">Results (Top-1 %)</th>
      <th colspan="2" align="center">Links</th>
  </tr>
  <tr>
      <th>Fine-tuning</th>
      <th>Pretrain</th>
      <th>Fine-tuning</th>
  </tr>
</thead>
<tbody>
  <tr>
      <td>MixMIM</td>
      <td>MixMIM-base</td>
      <td>300</td>
      <td>2048</td>
      <td>84.63</td>
      <td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221204_134711.json'>log</a></td>
      <td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
  </tr>
</tbody>
</table>

## Citation

```bibtex
@article{MixMIM2022,
  author = {Jihao Liu and Xin Huang and Yu Liu and Hongsheng Li},
  journal = {arXiv:2205.13137},
  title = {MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning},
  year = {2022},
}
```
@@ -0,0 +1,133 @@
_base_ = [
    'mmcls::_base_/models/mixmim/mixmim_base.py',
    'mmcls::_base_/datasets/imagenet_bs64_swin_224.py',
    'mmcls::_base_/default_runtime.py'
]

dataset_type = 'ImageNet'
preprocess_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
)

bgr_mean = preprocess_cfg['mean'][::-1]
bgr_std = preprocess_cfg['std'][::-1]
file_client_args = dict(backend='disk')
data_root = 'data/imagenet/'

train_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=bgr_mean,
        fill_std=bgr_std),
    dict(type='PackClsInputs'),
]

train_dataloader = dict(
    batch_size=128,
    num_workers=16,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
)

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='ResizeEdge',
        scale=256,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackClsInputs'),
]

val_dataloader = dict(
    batch_size=64,
    num_workers=8,
    pin_memory=True,
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
test_dataloader = val_dataloader

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        # total lr = base_lr * (num_gpus * batch_size_per_gpu) / 256 = 2e-3
        lr=5e-4 * (8 * 128 / 256),
        model_type='mixmim',
        layer_decay_rate=0.7,
        betas=(0.9, 0.999),
        weight_decay=0.05),
    constructor='mmselfsup.LearningRateDecayOptimWrapperConstructor',
    paramwise_cfg=dict(
        custom_keys={
            '.ln': dict(decay_mult=0.0),  # do not decay on ln and bias
            '.bias': dict(decay_mult=0.0)
        }))

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-6,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=95,
        eta_min=1e-6,
        by_epoch=True,
        begin=5,
        end=100,
        convert_to_iter_based=True)
]

train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=10)
val_cfg = dict()
test_cfg = dict()

default_hooks = dict(
    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
    logger=dict(type='LoggerHook', interval=100))
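The linear scaling rule in the `lr` comment above is easy to check by hand. A tiny sketch, assuming the standard rule where the effective batch size is the GPU count times the per-GPU batch size:

```python
# Linear LR scaling rule used in the comments of both configs in this PR.
def scaled_lr(base_lr: float, num_gpus: int, batch_per_gpu: int) -> float:
    """total_lr = base_lr * effective_batch_size / 256."""
    return base_lr * (num_gpus * batch_per_gpu) / 256

print(scaled_lr(5e-4, 8, 128))        # fine-tuning:  0.002  (2e-3)
print(scaled_lr(1.5e-4, 2 * 8, 128))  # pretraining: 0.0012 (1.2e-3), 2 nodes x 8 GPUs
```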
@@ -0,0 +1,35 @@
Collections:
  - Name: MixMIM
    Metadata:
      Training Data: ImageNet-1k
      Training Techniques:
        - AdamW
      Training Resources: 16x A100-80G GPUs
      Architecture:
        - MixMIM ViT
    Paper:
      URL: https://arxiv.org/abs/2205.13137
      Title: "MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning"
    README: configs/selfsup/mixmim/README.md

Models:
  - Name: mixmim-base-p16_16xb128-coslr-300e_in1k
    In Collection: MixMIM
    Metadata:
      Epochs: 300
      Batch Size: 2048
    Results: null
    Config: configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py
    Weights: https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth
    Downstream:
      - Type: Image Classification
        Metadata:
          Epochs: 100
          Batch Size: 1024
        Results:
          - Task: Fine-tuning
            Dataset: ImageNet-1k
            Metrics:
              Top 1 Accuracy: 84.63
        Config: configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py
        Weights: https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth
@@ -0,0 +1,46 @@
_base_ = [
    '../_base_/models/mixmim.py',
    '../_base_/datasets/imagenet_mixmim.py',
    '../_base_/schedules/adamw_coslr-200e_in1k.py',
    '../_base_/default_runtime.py',
]

# optimizer wrapper
optimizer = dict(
    type='AdamW',
    # 2 nodes x 8 GPUs x 128 images per GPU;
    # total lr = base_lr * (num_gpus * batch_size_per_gpu) / 256 = 1.2e-3
    lr=1.5e-4 * (2 * 8 * 128 / 256),
    betas=(0.9, 0.95),
    weight_decay=0.05)
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=optimizer,
    paramwise_cfg=dict(custom_keys={
        'ln': dict(decay_mult=0.0),
        'bias': dict(decay_mult=0.0)
    }))

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-4,
        by_epoch=True,
        begin=0,
        end=40,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=260,
        by_epoch=True,
        begin=40,
        end=300,
        convert_to_iter_based=True)
]

train_cfg = dict(max_epochs=300)
default_hooks = dict(
    logger=dict(type='LoggerHook', interval=100),
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=1))

# randomness
randomness = dict(seed=0, diff_rank_seed=True)
@ -430,5 +430,16 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. T
|
|||
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/elfsup/eva/classification/vit-base-p16_linear-8xb2048-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221226-ef51bf09.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221222_134137.json'>log</a></td>
|
||||
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/eva/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221226-f61cf992.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221221_212618.json'>log</a></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>MixMIM</td>
|
||||
<td>MixMIM-Base</td>
|
||||
<td>400</td>
|
||||
<td>2048</td>
|
||||
<td>/</td>
|
||||
<td>84.6</td>
|
||||
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221204_134711.json'>log</a></td>
|
||||
<td>/</td>
|
||||
<td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
|
|
@@ -430,5 +430,16 @@ ImageNet has multiple versions, but the most commonly used one is ILSVRC 2012. We provide
      <td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/eva/classification/vit-base-p16_linear-8xb2048-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221226-ef51bf09.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k/vit-base-p16_linear-8xb2048-coslr-100e_in1k_20221222_134137.json'>log</a></td>
      <td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/eva/classification/vit-base-p16_ft-8xb128-coslr-100e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221226-f61cf992.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/eva/eva-mae-style_vit-base-p16_16xb256-coslr-400e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20221221_212618.json'>log</a></td>
  </tr>
  <tr>
      <td>MixMIM</td>
      <td>MixMIM-Base</td>
      <td>300</td>
      <td>2048</td>
      <td>/</td>
      <td>84.6</td>
      <td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221208-44fe8d2c.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_16xb128-coslr-300e_in1k_20221204_134711.json'>log</a></td>
      <td>/</td>
      <td><a href='https://github.com/open-mmlab/mmselfsup/blob/dev-1.x/configs/selfsup/mixmim/classification/mixmim-base-p16_ft-8xb128-coslr-100e-in1k.py'>config</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221208-41ecada9.pth'>model</a> | <a href='https://download.openmmlab.com/mmselfsup/1.x/mixmim/mixmim-base-p16_16xb128-coslr-300e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k/mixmim-base-p16_ft-8xb128-coslr-100e_in1k_20221206_143046.json'>log</a></td>
  </tr>
</tbody>
</table>
@ -60,6 +60,40 @@ def get_layer_id_for_swin(var_name: str, max_layer_id: int,
|
|||
return max_layer_id - 1
|
||||
|
||||
|
||||
def get_layer_id_for_mixmim(var_name: str, max_layer_id: int,
|
||||
depths: List[int]) -> int:
|
||||
"""Get the layer id to set the different learning rates for MixMIM.
|
||||
|
||||
The layer is from 1 to max_layer_id (e.g. 25)
|
||||
Args:
|
||||
var_name (str): The key of the model.
|
||||
num_max_layer (int): Maximum number of backbone layers.
|
||||
depths (List[int]): Depths for each stage.
|
||||
Returns:
|
||||
int: Returns the layer id of the key.
|
||||
"""
|
||||
|
||||
if 'patch_embed' in var_name:
|
||||
return -1
|
||||
elif 'absolute_pos_embed' in var_name:
|
||||
return -1
|
||||
elif 'pos_embed' in var_name:
|
||||
return -1
|
||||
elif var_name.startswith('backbone.layers'):
|
||||
layer_id = int(var_name.split('.')[2])
|
||||
block_id = var_name.split('.')[4]
|
||||
|
||||
if block_id == 'downsample' or \
|
||||
block_id == 'reduction' or \
|
||||
block_id == 'norm':
|
||||
return sum(depths[:layer_id + 1]) - 1
|
||||
|
||||
layer_id = sum(depths[:layer_id]) + int(block_id) + 1
|
||||
return layer_id - 1
|
||||
else:
|
||||
return max_layer_id - 2
|
||||
|
||||
|
||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
|
||||
class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
|
||||
"""Different learning rates are set for different layers of backbone.
|
||||
|
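For intuition, here is how a few MixMIM-Base parameter names map to layer ids under this scheme. The depths `[2, 2, 18, 2]` are an assumption consistent with the docstring's example of 25 (`sum(depths) + 1`); the function below is condensed from the PR's version but returns the same ids, so the snippet runs standalone:

```python
# Condensed copy of get_layer_id_for_mixmim, for illustration only.
from typing import List

def get_layer_id_for_mixmim(var_name: str, max_layer_id: int,
                            depths: List[int]) -> int:
    if 'patch_embed' in var_name or 'pos_embed' in var_name:
        return -1
    elif var_name.startswith('backbone.layers'):
        stage_id = int(var_name.split('.')[2])
        block_id = var_name.split('.')[4]
        if block_id in ('downsample', 'reduction', 'norm'):
            return sum(depths[:stage_id + 1]) - 1
        return sum(depths[:stage_id]) + int(block_id)  # +1 then -1 in the PR
    else:
        return max_layer_id - 2

depths = [2, 2, 18, 2]        # assumed MixMIM-Base stage depths
num_layers = sum(depths) + 1  # 25, matching the docstring's example

for name in [
        'backbone.patch_embed.projection.weight',       # -> -1
        'backbone.layers.0.blocks.1.norm1.weight',      # -> 1
        'backbone.layers.2.blocks.17.attn.qkv.weight',  # -> 21
        'backbone.layers.2.downsample.norm.weight',     # -> 21
        'backbone.norm.weight',                         # -> 23 (max - 2)
]:
    print(name, get_layer_id_for_mixmim(name, num_layers, depths))
```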
@ -113,13 +147,16 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
|
|||
|
||||
# currently, we only support layer-wise learning rate decay for vit
|
||||
# and swin.
|
||||
assert model_type in ['vit', 'swin'], f'Currently, we do not support \
|
||||
assert model_type in ['vit', 'swin',
|
||||
'mixmim'], f'Currently, we do not support \
|
||||
layer-wise learning rate decay for {model_type}'
|
||||
|
||||
if model_type == 'vit':
|
||||
num_layers = len(module.backbone.layers) + 2
|
||||
elif model_type == 'swin':
|
||||
num_layers = sum(module.backbone.depths) + 2
|
||||
elif model_type == 'mixmim':
|
||||
num_layers = sum(module.backbone.depths) + 1
|
||||
|
||||
# if layer_decay_rate is not provided, not decay
|
||||
decay_rate = optimizer_cfg.pop('layer_decay_rate', 1.0)
|
||||
|
@@ -146,6 +183,9 @@ class LearningRateDecayOptimWrapperConstructor(DefaultOptimWrapperConstructor):
        elif model_type == 'swin':
            layer_id = get_layer_id_for_swin(name, num_layers,
                                             module.backbone.depths)
        elif model_type == 'mixmim':
            layer_id = get_layer_id_for_mixmim(name, num_layers,
                                               module.backbone.depths)

        group_name = f'layer_{layer_id}_{group_name}'
        if group_name not in parameter_groups:
@@ -10,6 +10,7 @@ from .eva import EVA
from .mae import MAE
from .maskfeat import MaskFeat
from .milan import MILAN
from .mixmim import MixMIM
from .moco import MoCo
from .mocov3 import MoCoV3
from .npid import NPID

@@ -24,5 +25,6 @@ from .swav import SwAV
__all__ = [
    'BaseModel', 'BarlowTwins', 'BEiT', 'BYOL', 'DeepCluster', 'DenseCL',
    'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR', 'SimSiam',
    'SwAV', 'MAE', 'MoCoV3', 'SimMIM', 'CAE', 'MaskFeat', 'MILAN', 'EVA',
    'MixMIM'
]
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
from torch import nn

from mmselfsup.registry import MODELS
from mmselfsup.structures import SelfSupDataSample
from .base import BaseModel


@MODELS.register_module()
class MixMIM(BaseModel):
    """MixMIM.

    Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient
    Visual Representation Learning <https://arxiv.org/abs/2205.13137>`_.
    """

    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):

        # the head reconstructs patches of the backbone's output stride
        head.update(dict(patch_size=neck['encoder_stride']))
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def loss(self, inputs: List[torch.Tensor],
             data_samples: List[SelfSupDataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[SelfSupDataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        latent, mask = self.backbone(inputs[0])
        x_rec = self.neck(latent, mask)
        loss = self.head(x_rec, inputs[0], mask)
        losses = dict(loss=loss)
        return losses
@@ -4,6 +4,7 @@ from .cae_vit import CAEViT
from .mae_vit import MAEViT
from .maskfeat_vit import MaskFeatViT
from .milan_vit import MILANViT
from .mixmim_backbone import MixMIMTransformerPretrain
from .mocov3_vit import MoCoV3ViT
from .resnet import ResNet, ResNetSobel, ResNetV1d
from .resnext import ResNeXt

@@ -11,5 +12,6 @@ from .simmim_swin import SimMIMSwinTransformer

__all__ = [
    'ResNet', 'ResNetSobel', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MoCoV3ViT',
    'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT', 'BEiTViT', 'MILANViT',
    'MixMIMTransformerPretrain'
]
@@ -0,0 +1,200 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional, Union

import torch
from mmcls.models.backbones import MixMIMTransformer
from torch import nn
from torch.nn import functional as F

from mmselfsup.registry import MODELS
from ..utils import build_2d_sincos_position_embedding


@MODELS.register_module()
class MixMIMTransformerPretrain(MixMIMTransformer):
    """MixMIM backbone during pretraining.

    A PyTorch implementation of: `MixMIM: Mixed and Masked Image
    Modeling for Efficient Visual Representation Learning
    <https://arxiv.org/abs/2205.13137>`_.

    Args:
        arch (str | dict): MixMIM architecture. If use string,
            choose from 'base', 'large' and 'huge'.
            If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.

            Defaults to 'base'.
        mlp_ratio (float): The mlp ratio in FFN. Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (list): The height and width of the window for each
            stage. Defaults to [14, 14, 14, 7].
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        use_checkpoint (bool): Whether to use checkpointing to reduce GPU
            memory cost. Defaults to False.
        range_mask_ratio (float): The range of random variation added to
            the mask ratio. Defaults to 0.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'base',
                 mlp_ratio: float = 4,
                 img_size: int = 224,
                 patch_size: int = 4,
                 in_channels: int = 3,
                 window_size: List = [14, 14, 14, 7],
                 qkv_bias: bool = True,
                 patch_cfg: dict = dict(),
                 norm_cfg: dict = dict(type='LN'),
                 drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 use_checkpoint: bool = False,
                 range_mask_ratio: float = 0.0,
                 init_cfg: Optional[dict] = None) -> None:

        super().__init__(
            arch=arch,
            mlp_ratio=mlp_ratio,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            window_size=window_size,
            qkv_bias=qkv_bias,
            patch_cfg=patch_cfg,
            norm_cfg=norm_cfg,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            use_checkpoint=use_checkpoint,
            init_cfg=init_cfg)

        self.range_mask_ratio = range_mask_ratio

    def init_weights(self):
        """Initialize position embedding, patch embedding."""
        super(MixMIMTransformer, self).init_weights()

        pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.absolute_pos_embed.shape[-1],
            cls_token=False)
        self.absolute_pos_embed.data.copy_(pos_embed.float())

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x: torch.Tensor, mask_ratio: float = 0.5):
        """Generate the mask for MixMIM Pretraining.

        Args:
            x (torch.Tensor): Image with data augmentation applied, which is
                of shape B x C x H x W.
            mask_ratio (float): The mask ratio of total patches.
                Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

            - mask_s1 (torch.Tensor): mask with stride of
              self.encoder_stride // 8.
            - mask_s2 (torch.Tensor): mask with stride of
              self.encoder_stride // 4.
            - mask_s3 (torch.Tensor): mask with stride of
              self.encoder_stride // 2.
            - mask (torch.Tensor): mask with stride of
              self.encoder_stride.
        """
        B, C, H, W = x.shape
        out_H = H // self.encoder_stride
        out_W = W // self.encoder_stride
        s3_H, s3_W = out_H * 2, out_W * 2
        s2_H, s2_W = out_H * 4, out_W * 4
        s1_H, s1_W = out_H * 8, out_W * 8

        seq_l = out_H * out_W
        # use a shared mask for a batch of images
        mask = torch.zeros([1, 1, seq_l], device=x.device)

        mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
        noise = torch.rand(1, 1, seq_l, device=x.device)  # noise in [0, 1]
        # ascend: small is kept, large is removed
        mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
        mask.scatter_(2, mask_idx, 1)
        mask = mask.reshape(1, 1, out_H, out_W)
        mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
        mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
        mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')

        mask = mask.reshape(1, out_H * out_W, 1).contiguous()
        mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
        mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
        mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()

        return mask_s1, mask_s2, mask_s3, mask

    def forward(self, x: torch.Tensor, mask_ratio: float = 0.5):
        """Generate features for masked images.

        This function generates a mask, masks some patches randomly and
        gets the hidden features for visible patches.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask_ratio (float): The mask ratio of total patches.
                Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            - x (torch.Tensor): hidden features, which is of shape
              B x L x C.
            - mask_s4 (torch.Tensor): the mask tensor for the last layer.
        """
        mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(
            x, mask_ratio)

        x, _ = self.patch_embed(x)

        # mix the batch: masked positions of image i take the tokens of
        # image B - 1 - i (the batch flipped along dim 0)
        x = x * (1. - mask_s1) + x.flip(0) * mask_s1
        x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        for idx, layer in enumerate(self.layers):
            if idx == 0:
                x = layer(x, attn_mask=mask_s1)
            elif idx == 1:
                x = layer(x, attn_mask=mask_s2)
            elif idx == 2:
                x = layer(x, attn_mask=mask_s3)
            elif idx == 3:
                x = layer(x, attn_mask=mask_s4)

        x = self.norm(x)

        return x, mask_s4
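A standalone sketch of the multi-scale masks this backbone generates, replicating `random_masking` outside the class. The 224x224 input and `encoder_stride=32` match the base config; the printed shapes are what follow from those assumptions:

```python
# Multi-scale mask generation as in random_masking above (illustration).
import torch
import torch.nn.functional as F

def multi_scale_masks(x, encoder_stride=32, mask_ratio=0.5):
    B, C, H, W = x.shape
    out_H, out_W = H // encoder_stride, W // encoder_stride
    seq_l = out_H * out_W
    mask = torch.zeros([1, 1, seq_l], device=x.device)
    noise = torch.rand(1, 1, seq_l, device=x.device)
    mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
    mask.scatter_(2, mask_idx, 1)
    mask = mask.reshape(1, 1, out_H, out_W)
    # upsample the shared mask to strides encoder_stride // {8, 4, 2, 1}
    return [
        F.interpolate(mask, size=(out_H * s, out_W * s), mode='nearest')
        .reshape(1, out_H * s * out_W * s, 1) for s in (8, 4, 2, 1)
    ]

masks = multi_scale_masks(torch.randn(2, 3, 224, 224))
print([tuple(m.shape) for m in masks])
# [(1, 3136, 1), (1, 784, 1), (1, 196, 1), (1, 49, 1)]  # 56^2 ... 7^2 tokens
```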
@@ -8,6 +8,7 @@ from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead
from .mae_head import MAEPretrainHead
from .maskfeat_head import MaskFeatPretrainHead
from .milan_head import MILANPretrainHead
from .mixmim_head import MixMIMPretrainHead
from .mocov3_head import MoCoV3Head
from .multi_cls_head import MultiClsHead
from .simmim_head import SimMIMHead

@@ -17,5 +18,5 @@ __all__ = [
    'BEiTV1Head', 'BEiTV2Head', 'ContrastiveHead', 'ClsHead',
    'LatentPredictHead', 'LatentCrossCorrelationHead', 'MultiClsHead',
    'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead', 'CAEHead', 'SwAVHead',
    'MaskFeatPretrainHead', 'MILANPretrainHead', 'MixMIMPretrainHead'
]
@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmselfsup.registry import MODELS
from .mae_head import MAEPretrainHead


@MODELS.register_module()
class MixMIMPretrainHead(MAEPretrainHead):
    """MixMIM pretrain head.

    Args:
        loss (dict): Config of loss.
        norm_pix (bool): Whether or not to normalize the target.
            Defaults to False.
        patch_size (int): Patch size. Defaults to 16.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = False,
                 patch_size: int = 16) -> None:
        super().__init__(loss=loss, norm_pix=norm_pix, patch_size=patch_size)

    def forward(self, x_rec: torch.Tensor, target: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Forward function of MixMIM head.

        Args:
            x_rec (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        target = self.construct_target(target)

        B, L, C = x_rec.shape

        # unmix the dual reconstructions: for each image, masked positions
        # come from the first branch and visible positions from its partner
        # in the second branch (batch flipped back), so every position of
        # the original image is supervised
        x1_rec = x_rec[:B // 2]
        x2_rec = x_rec[B // 2:]

        unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask)

        loss_rec = self.loss(unmix_x_rec, target)

        return loss_rec
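To see what the unmixing line does, a toy check with tiny tensors (not from the PR; the branch layout matches the decoder's `torch.cat([x1, x2], dim=0)` shown later in this diff):

```python
# Toy check (illustrative) that unmixing recombines the two branches.
import torch

B, L = 2, 4
mask = torch.tensor([0., 1., 0., 1.]).reshape(1, L, 1)

x_rec = torch.arange(2 * B * L, dtype=torch.float).reshape(2 * B, L, 1)
x1_rec, x2_rec = x_rec[:B], x_rec[B:]  # branch 1, branch 2

unmix = x1_rec * mask + x2_rec.flip(0) * (1 - mask)
print(unmix.squeeze(-1))
# row i mixes branch-1 predictions (masked slots) with the flipped
# branch-2 predictions (visible slots):
# tensor([[12.,  1., 14.,  3.],
#         [ 8.,  5., 10.,  7.]])
```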
@@ -38,8 +38,10 @@ class PixelReconstructionLoss(BaseModule):

        self.channel = channel if channel is not None else 1

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function to compute the reconstruction loss.

        Args:

@@ -57,6 +59,9 @@ class PixelReconstructionLoss(BaseModule):
        if len(loss.shape) == 3:
            loss = loss.mean(dim=-1)

        # MixMIM supervises every position via dual reconstruction, so it
        # passes no mask; average over all elements in that case
        if mask is None:
            loss = loss.mean()
        else:
            loss = (loss * mask).sum() / mask.sum() / self.channel

        return loss
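The effect of the now-optional mask, sketched with hypothetical per-patch losses and `channel=1`:

```python
# Sketch of the two averaging paths in PixelReconstructionLoss (assumed
# shapes: per-patch loss (B, L), binary mask (B, L) or None).
import torch

def pixel_loss(loss: torch.Tensor, mask=None, channel: int = 1):
    if mask is None:
        return loss.mean()                           # MixMIM: all patches
    return (loss * mask).sum() / mask.sum() / channel  # masked patches only

per_patch = torch.tensor([[1., 2., 3., 4.]])  # (B=1, L=4) per-patch L2 loss
mask = torch.tensor([[0., 1., 0., 1.]])       # only patches 1 and 3 masked

print(pixel_loss(per_patch))        # 2.5: mean over all patches (MixMIM)
print(pixel_loss(per_patch, mask))  # 3.0: masked mean (MAE/SimMIM style)
```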
@@ -6,6 +6,7 @@ from .densecl_neck import DenseCLNeck
from .linear_neck import LinearNeck
from .mae_neck import ClsBatchNormNeck, MAEPretrainDecoder
from .milan_neck import MILANPretrainDecoder
from .mixmim_neck import MixMIMPretrainDecoder
from .mocov2_neck import MoCoV2Neck
from .nonlinear_neck import NonLinearNeck
from .odc_neck import ODCNeck

@@ -16,6 +17,6 @@ from .swav_neck import SwAVNeck
__all__ = [
    'AvgPool2dNeck', 'BEiTV2Neck', 'DenseCLNeck', 'LinearNeck', 'MoCoV2Neck',
    'NonLinearNeck', 'ODCNeck', 'RelativeLocNeck', 'SwAVNeck',
    'MAEPretrainDecoder', 'SimMIMNeck', 'CAENeck', 'MixMIMPretrainDecoder',
    'ClsBatchNormNeck', 'MILANPretrainDecoder'
]
@@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn

from mmselfsup.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
from .mae_neck import MAEPretrainDecoder


@MODELS.register_module()
class MixMIMPretrainDecoder(MAEPretrainDecoder):
    """Decoder for MixMIM Pretraining.

    Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. # noqa

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
        encoder_stride (int): The output stride of MixMIM backbone. Defaults
            to 32.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 8.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 encoder_stride: int = 32,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:

        super().__init__(
            num_patches=num_patches,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim),
            requires_grad=False)
        self.decoder_pred = nn.Linear(decoder_embed_dim, encoder_stride**2 * 3)

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MixMIM decoder."""
        super(MAEPretrainDecoder, self).init_weights()

        decoder_pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.decoder_pos_embed.shape[-1],
            cls_token=False)
        self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

        torch.nn.init.normal_(self.mask_token, std=.02)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): The input features, which is of shape
                (N, L, C).
            mask (torch.Tensor): The tensor to indicate which tokens are
                masked.

        Returns:
            torch.Tensor: The reconstructed features, which is of shape
                (2 * N, L, encoder_stride ** 2 * 3). The two halves of the
                batch are the dual reconstruction branches.
        """
        x = self.decoder_embed(x)
        B, L, C = x.shape

        # build the two branches: one keeps the visible tokens, the other
        # keeps the mixed-in tokens; masked slots are filled with mask tokens
        mask_tokens = self.mask_token.expand(B, L, -1)
        x1 = x * (1 - mask) + mask_tokens * mask
        x2 = x * mask + mask_tokens * (1 - mask)
        x = torch.cat([x1, x2], dim=0)

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x
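A toy view (not from the PR) of the two branches the decoder builds from one mixed token sequence; branch 1 keeps the visible tokens, branch 2 the mixed-in ones, and masked slots get a stand-in for the learnable mask token:

```python
# Toy illustration of x1/x2 branch construction in the decoder forward.
import torch

N, L, C = 2, 4, 1
x = torch.arange(N * L, dtype=torch.float).reshape(N, L, C)  # mixed tokens
mask = torch.tensor([0., 1., 0., 1.]).reshape(1, L, 1)
mask_tokens = torch.full((1, 1, C), -1.).expand(N, L, -1)    # fake mask token

x1 = x * (1 - mask) + mask_tokens * mask          # visible tokens + [MASK]
x2 = x * mask + mask_tokens * (1 - mask)          # mixed-in tokens + [MASK]
print(torch.cat([x1, x2], dim=0).squeeze(-1))
# tensor([[ 0., -1.,  2., -1.],
#         [ 4., -1.,  6., -1.],
#         [-1.,  1., -1.,  3.],
#         [-1.,  5., -1.,  7.]])
```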
@@ -21,3 +21,4 @@ Import:
- configs/selfsup/milan/metafile.yaml
- configs/selfsup/beitv2/metafile.yml
- configs/selfsup/eva/metafile.yaml
- configs/selfsup/mixmim/metafile.yaml