[Feature] Add mixmim backbone with checkpoints. (#1224)
* add mixmim backbone
* add mixmim inference
* add docstring, metafile, test and modify readme
* Update README and metafile

Co-authored-by: mzr1996 <mzr1996@163.com>
parent 7dcf34533d
commit 14dcb69092
@@ -154,6 +154,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea

- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
- [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2)
- [x] [EVA](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/eva)
- [x] [MixMIM](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mixmim)

</details>
@@ -155,6 +155,7 @@ mim install -e .

- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
- [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2)
- [x] [EVA](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/eva)
- [x] [MixMIM](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mixmim)

</details>
@@ -0,0 +1,20 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='MixMIMTransformer', arch='B', drop_rate=0.0, drop_path_rate=0.1),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        init_cfg=None,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
        cal_acc=False),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
    ],
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))
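As a quick sanity check, the model settings above can be built directly through the registry. A minimal sketch, assuming an mmcls 1.x installation (the random input and the trimmed-down head are illustrative only):

```python
import torch

import mmcls.models  # noqa: F401  # importing registers ImageClassifier, MixMIMTransformer, ...
from mmcls.registry import MODELS

model = MODELS.build(
    dict(
        type='ImageClassifier',
        backbone=dict(type='MixMIMTransformer', arch='B'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=1024,
            loss=dict(type='CrossEntropyLoss'))))
scores = model(torch.rand(1, 3, 224, 224))  # default 'tensor' mode returns raw class scores
print(scores.shape)  # torch.Size([1, 1000])
```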
@@ -0,0 +1,90 @@
# MixMIM

> [MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning](https://arxiv.org/abs/2205.13137)

<!-- [ALGORITHM] -->

## Abstract

In this study, we propose Mixed and Masked Image Modeling (MixMIM), a simple but efficient MIM method that is applicable to various hierarchical Vision Transformers. Existing MIM methods replace a random subset of input tokens with a special [MASK] symbol and aim at reconstructing original image tokens from the corrupted image. However, we find that using the [MASK] symbol greatly slows down the training and causes training-finetuning inconsistency, due to the large masking ratio (e.g., 40% in BEiT). In contrast, we replace the masked tokens of one image with visible tokens of another image, i.e., creating a mixed image. We then conduct dual reconstruction to reconstruct the original two images from the mixed input, which significantly improves efficiency. While MixMIM can be applied to various architectures, this paper explores a simpler but stronger hierarchical Transformer, and scales with MixMIM-B, -L, and -H. Empirical results demonstrate that MixMIM can learn high-quality visual representations efficiently. Notably, MixMIM-B with 88M parameters achieves 85.1% top-1 accuracy on ImageNet-1K by pretraining for 600 epochs, setting a new record for neural networks with comparable model sizes (e.g., ViT-B) among MIM methods. Besides, its transferring performances on the other 6 datasets show MixMIM has better FLOPs / performance tradeoff than previous MIM methods.

<div align=center>
<img src="https://user-images.githubusercontent.com/56866854/202853730-d26fb3d7-e5e8-487a-aad5-e3d4600cef87.png"/>
</div>
## How to use it?

### Inference

<!-- [TABS-BEGIN] -->

**Predict image**

```python
>>> import torch
>>> import mmcls
>>> model = mmcls.get_model('mixmim-base_3rdparty_in1k', pretrained=True)
>>> predict = mmcls.inference_model(model, 'demo/demo.JPEG')
>>> print(predict['pred_class'])
sea snake
>>> print(predict['pred_score'])
0.865431010723114
```

**Use the model**

```python
>>> import torch
>>> import mmcls
>>>
>>> model = mmcls.get_model('mixmim-base_3rdparty_in1k', pretrained=True)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> # To get classification scores.
>>> out = model(inputs)
>>> print(out.shape)
torch.Size([1, 1000])
>>> # To extract features.
>>> outs = model.extract_feat(inputs)
>>> print(outs[0].shape)
torch.Size([1, 1024])
```

<!-- [TABS-END] -->
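The model name used above comes from this repository's metafile; it can also be looked up programmatically (a sketch, assuming the top-level `list_models` helper shipped with mmcls 1.x):

```python
>>> import mmcls
>>> mmcls.list_models('mixmim')
['mixmim-base_3rdparty_in1k']
```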
## Models

| Model                       | Params(M) | Pretrain Epochs | Flops(G) | Top-1 (%) | Top-5 (%) | Config                                | Download                                                                                                           |
| :-------------------------: | :-------: | :-------------: | :------: | :-------: | :-------: | :-----------------------------------: | :----------------------------------------------------------------------------------------------------------------: |
| mixmim-base_3rdparty_in1k\* | 88        | 300             | 16.3     | 84.6      | 97.0      | [config](./mixmim-base_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mixmim/mixmim-base_3rdparty_in1k_20221206-e40e2c8c.pth) |

*Models with \* are converted from the [official repo](https://github.com/Sense-X/MixMIM). The config files of these models are only for inference.*

For the MixMIM self-supervised learning algorithm, see the [MMSelfSup page](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim) for more information.
## Citation

```bibtex
@article{MixMIM2022,
  author  = {Jihao Liu and Xin Huang and Yu Liu and Hongsheng Li},
  journal = {arXiv:2205.13137},
  title   = {MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning},
  year    = {2022},
}
```
@@ -0,0 +1,39 @@
Collections:
  - Name: MixMIM
    Metadata:
      Architecture:
        - Attention Dropout
        - Convolution
        - Dense Connections
        - Dropout
        - GELU
        - Layer Normalization
        - Multi-Head Attention
        - Scaled Dot-Product Attention
        - Tanh Activation
    Paper:
      Title: 'MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning'
      URL: https://arxiv.org/abs/2205.13137
    README: configs/mixmim/README.md
    Code:
      URL: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/mmcls/models/backbones/mixmim.py
      Version: v1.0.0rc4

Models:
  - Name: mixmim-base_3rdparty_in1k
    Metadata:
      FLOPs: 16352000000
      Parameters: 88344000
      Training Data:
        - ImageNet-1k
    In Collection: MixMIM
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 84.6
          Top 5 Accuracy: 97.0
    Weights: https://download.openmmlab.com/mmclassification/v0/mixmim/mixmim-base_3rdparty_in1k_20221206-e40e2c8c.pth
    Config: configs/mixmim/mixmim-base_8xb64_in1k.py
    Converted From:
      Code: https://github.com/Sense-X/MixMIM
@@ -0,0 +1,5 @@
_base_ = [
    '../_base_/models/mixmim/mixmim_base.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
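When this config is loaded, mmengine merges the four `_base_` files into one flat config. A quick way to inspect the merged result (a sketch; run from the repository root so the relative `_base_` paths resolve):

```python
from mmengine.config import Config

cfg = Config.fromfile('configs/mixmim/mixmim-base_8xb64_in1k.py')
print(cfg.model.backbone.type)     # MixMIMTransformer
print(cfg.model.head.num_classes)  # 1000
```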
@@ -16,6 +16,7 @@ from .hornet import HorNet
from .hrnet import HRNet
from .inception_v3 import InceptionV3
from .lenet import LeNet5
from .mixmim import MixMIMTransformer
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
@@ -102,5 +103,6 @@ __all__ = [
    'DaViT',
    'BEiT',
    'RevVisionTransformer',
    'MixMIMTransformer',
    'TinyViT',
]
@@ -0,0 +1,494 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging
from mmengine.model import BaseModule
from torch import nn
from torch.utils.checkpoint import checkpoint

from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from mmcls.models.utils.attention import WindowMSA
from mmcls.models.utils.helpers import to_2tuple
from mmcls.registry import MODELS
class MixMIMWindowAttention(WindowMSA):
    """MixMIM Window Attention.

    Compared with WindowMSA, we add some modifications in ``forward`` to meet
    the requirement of MixMIM during pretraining.

    Implements one window attention in MixMIM.

    Args:
        embed_dims (int): The feature dimension.
        window_size (list): The height and width of the window.
        num_heads (int): The number of heads in attention.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        proj_drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None):

        super().__init__(
            embed_dims=embed_dims,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop_rate,
            proj_drop=proj_drop_rate,
            init_cfg=init_cfg)
    def forward(self, x, mask=None):

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # ``mask`` is a 0/1 indicator of which source image each token
            # of the mixed input comes from. ``mask_new[i, j]`` is 1 when
            # tokens i and j come from the same image, so ``1 - mask_new``
            # marks cross-image pairs, whose attention logits are pushed
            # towards -inf below.
            mask = mask.reshape(B_, 1, 1, N)
            mask_new = mask * mask.transpose(
                2, 3) + (1 - mask) * (1 - mask).transpose(2, 3)
            mask_new = 1 - mask_new

            if mask_new.dtype == torch.float16:
                attn = attn - 65500 * mask_new
            else:
                attn = attn - 1e30 * mask_new

            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class MixMIMBlock(TransformerEncoderLayer):
    """MixMIM Block. Implements one block in MixMIM.

    Args:
        embed_dims (int): The feature dimension.
        input_resolution (tuple): Input resolution of this layer.
        num_heads (int): The number of heads in attention.
        window_size (list): The height and width of the window.
        mlp_ratio (int): The MLP ratio in FFN.
        num_fcs (int): The number of linear layers in a block.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        proj_drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    def __init__(self,
                 embed_dims,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 num_fcs=2,
                 qkv_bias=True,
                 proj_drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:

        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=int(mlp_ratio * embed_dims),
            drop_rate=proj_drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            num_fcs=num_fcs,
            qkv_bias=qkv_bias,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        if min(self.input_resolution) <= self.window_size:
            self.window_size = min(self.input_resolution)

        self.attn = MixMIMWindowAttention(
            embed_dims=embed_dims,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate)

        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()
    @staticmethod
    def window_reverse(windows, H, W, window_size):
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    @staticmethod
    def window_partition(x, window_size):
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows
    def forward(self, x, attn_mask=None):
        H, W = self.input_resolution
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # partition windows
        x_windows = self.window_partition(
            x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size,
                                   C)  # nW*B, window_size*window_size, C
        if attn_mask is not None:
            attn_mask = attn_mask.repeat(B, 1, 1)  # B, N, 1
            attn_mask = attn_mask.view(B, H, W, 1)
            attn_mask = self.window_partition(attn_mask, self.window_size)
            attn_mask = attn_mask.view(
                -1, self.window_size * self.window_size, 1)

        # W-MSA/SW-MSA
        attn_windows = self.attn(
            x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)
        x = self.window_reverse(attn_windows, H, W,
                                self.window_size)  # B H' W' C

        x = x.view(B, H * W, C)

        x = shortcut + self.drop_path(x)

        x = self.ffn(self.norm2(x), identity=x)  # ffn contains DropPath

        return x
class MixMIMLayer(BaseModule):
    """Implements one MixMIM layer, which may contain several MixMIM blocks.

    Args:
        embed_dims (int): The feature dimension.
        input_resolution (tuple): Input resolution of this layer.
        depth (int): The number of blocks in this layer.
        num_heads (int): The number of heads in attention.
        window_size (list): The height and width of the window.
        mlp_ratio (int): The MLP ratio in FFN.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        proj_drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        downsample (class, optional): Downsample the output of blocks by
            patch merging. Defaults to None.
        use_checkpoint (bool): Whether to use checkpointing to reduce GPU
            memory cost.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    def __init__(self,
                 embed_dims: int,
                 input_resolution: tuple,
                 depth: int,
                 num_heads: int,
                 window_size: int,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 proj_drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=[0.],
                 norm_cfg=dict(type='LN'),
                 downsample=None,
                 use_checkpoint=False,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList()
        for i in range(depth):
            self.blocks.append(
                MixMIMBlock(
                    embed_dims=embed_dims,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_drop_rate=proj_drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=drop_path_rate[i],
                    norm_cfg=norm_cfg))
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                in_channels=embed_dims,
                out_channels=2 * embed_dims,
                norm_cfg=norm_cfg)
        else:
            self.downsample = None
    def forward(self, x, attn_mask=None):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask=attn_mask)
        if self.downsample is not None:
            x, _ = self.downsample(x, self.input_resolution)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.embed_dims}, \
            input_resolution={self.input_resolution}, depth={self.depth}'
@MODELS.register_module()
class MixMIMTransformer(BaseBackbone):
    """MixMIM backbone.

    A PyTorch implementation of: `MixMIM: Mixed and Masked Image
    Modeling for Efficient Visual Representation Learning
    <https://arxiv.org/abs/2205.13137>`_

    Args:
        arch (str | dict): MixMIM architecture. If use string,
            choose from 'base', 'large' and 'huge'.
            If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.

            Defaults to 'base'.
        mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (list): The height and width of the window.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        use_checkpoint (bool): Whether to use checkpointing to reduce GPU
            memory cost.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 128,
                'depths': [2, 2, 18, 2],
                'num_heads': [4, 8, 16, 32]
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 192,
                'depths': [2, 2, 18, 2],
                'num_heads': [6, 12, 24, 48]
            }),
        **dict.fromkeys(
            ['h', 'huge'], {
                'embed_dims': 352,
                'depths': [2, 2, 18, 2],
                'num_heads': [11, 22, 44, 88]
            }),
    }
    def __init__(
        self,
        arch='base',
        mlp_ratio=4,
        img_size=224,
        patch_size=4,
        in_channels=3,
        window_size=[14, 14, 14, 7],
        qkv_bias=True,
        patch_cfg=dict(),
        norm_cfg=dict(type='LN'),
        drop_rate=0.0,
        drop_path_rate=0.0,
        attn_drop_rate=0.0,
        use_checkpoint=False,
        init_cfg: Optional[dict] = None,
    ) -> None:
        super(MixMIMTransformer, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            # These are the keys actually consumed below.
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch
        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']

        self.encoder_stride = 32

        self.num_layers = len(self.depths)
        self.qkv_bias = qkv_bias
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.use_checkpoint = use_checkpoint
        self.mlp_ratio = mlp_ratio
        self.window_size = window_size

        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            norm_cfg=dict(type='LN'),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        self.dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            self.layers.append(
                MixMIMLayer(
                    embed_dims=int(self.embed_dims * 2**i_layer),
                    input_resolution=(
                        self.patch_resolution[0] // (2**i_layer),
                        self.patch_resolution[1] // (2**i_layer)),
                    depth=self.depths[i_layer],
                    num_heads=self.num_heads[i_layer],
                    window_size=self.window_size[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=self.qkv_bias,
                    proj_drop_rate=self.drop_rate,
                    attn_drop_rate=self.attn_drop_rate,
                    drop_path_rate=self.dpr[
                        sum(self.depths[:i_layer]):
                        sum(self.depths[:i_layer + 1])],
                    norm_cfg=norm_cfg,
                    downsample=PatchMerging if
                    (i_layer < self.num_layers - 1) else None,
                    use_checkpoint=self.use_checkpoint))

        self.num_features = int(self.embed_dims * 2**(self.num_layers - 1))
        self.drop_after_pos = nn.Dropout(p=self.drop_rate)

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
        self.absolute_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, self.embed_dims),
            requires_grad=False)

        _, self.norm = build_norm_layer(norm_cfg, self.num_features)
    def forward(self, x: torch.Tensor):
        x, _ = self.patch_embed(x)

        x = x + self.absolute_pos_embed
        x = self.drop_after_pos(x)

        for layer in self.layers:
            x = layer(x, attn_mask=None)

        x = self.norm(x)
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)

        return (x, )
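A minimal usage sketch of the backbone above, mirroring what the unit test added in this PR exercises (assumes mmcls 1.x; the output is the globally average-pooled feature of the last stage):

```python
import torch

from mmcls.models.backbones import MixMIMTransformer

model = MixMIMTransformer(arch='base')
model.eval()
feats = model(torch.rand(1, 3, 224, 224))
print(feats[0].shape)  # torch.Size([1, 1024])
```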
@@ -45,3 +45,4 @@ Import:
  - configs/beitv2/metafile.yml
  - configs/eva/metafile.yml
  - configs/revvit/metafile.yml
  - configs/mixmim/metafile.yml
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase

import torch

from mmcls.models.backbones import MixMIMTransformer


class TestMixMIM(TestCase):

    def setUp(self):
        self.cfg = dict(arch='b', drop_rate=0.0, drop_path_rate=0.1)

    def test_structure(self):

        # Test the default 'b' arch
        cfg = deepcopy(self.cfg)

        model = MixMIMTransformer(**cfg)
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(sum(model.depths), 24)
        self.assertIsNotNone(model.absolute_pos_embed)

        num_heads = [4, 8, 16, 32]
        for i, layer in enumerate(model.layers):
            self.assertEqual(layer.blocks[0].num_heads, num_heads[i])
            self.assertEqual(layer.blocks[0].ffn.feedforward_channels,
                             128 * (2**i) * 4)

    def test_forward(self):
        imgs = torch.randn(1, 3, 224, 224)

        cfg = deepcopy(self.cfg)
        model = MixMIMTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        averaged_token = outs[-1]
        self.assertEqual(averaged_token.shape, (1, 1024))
@@ -0,0 +1,98 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def correct_unfold_reduction_order(x: torch.Tensor):
    out_channel, in_channel = x.shape
    x = x.reshape(out_channel, 4, in_channel // 4)
    x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
    return x


def correct_unfold_norm_order(x):
    in_channel = x.shape[0]
    x = x.reshape(4, in_channel // 4)
    x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
    return x


def convert_mixmim(ckpt):

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        new_v = v

        if k.startswith('patch_embed'):
            new_k = k.replace('proj', 'projection')

        elif k.startswith('layers'):
            if 'norm1' in k:
                new_k = k.replace('norm1', 'ln1')
            elif 'norm2' in k:
                new_k = k.replace('norm2', 'ln2')
            elif 'mlp.fc1' in k:
                new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
            elif 'mlp.fc2' in k:
                new_k = k.replace('mlp.fc2', 'ffn.layers.1')
            else:
                new_k = k

        elif k.startswith('norm') or k.startswith('absolute_pos_embed'):
            new_k = k

        elif k.startswith('head'):
            new_k = k.replace('head.', 'head.fc.')

        else:
            raise ValueError(f'Unexpected key: {k}')

        if not new_k.startswith('head'):
            new_k = 'backbone.' + new_k

        if 'downsample' in new_k:
            print('Convert {} in PatchMerging from timm to mmcv '
                  'format!'.format(new_k))

            if 'reduction' in new_k:
                new_v = correct_unfold_reduction_order(new_v)
            elif 'norm' in new_k:
                new_v = correct_unfold_norm_order(new_v)

        new_ckpt[new_k] = new_v

    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained MixMIM models to mmcls style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    weight = convert_mixmim(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
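A toy check of the two reordering helpers above. PatchMerging concatenates the four spatial shifts of a feature map, and the timm-style and mmcv implementations concatenate the groups in different orders, so converted weights must be permuted group-wise with the index `[0, 2, 1, 3]`. With `in_channel=4` each group holds a single channel, which makes the permutation easy to see (illustrative only):

```python
import torch

w = torch.arange(8.).reshape(2, 4)  # (out_channel=2, in_channel=4)
w2 = correct_unfold_reduction_order(w)
# With one channel per group, each row [a, b, c, d] becomes [a, c, b, d].
assert torch.equal(w2, w[:, [0, 2, 1, 3]])
```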