diff --git a/README.md b/README.md
index 3bdb53c7..95f3c9f5 100644
--- a/README.md
+++ b/README.md
@@ -154,6 +154,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
- [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2)
- [x] [EVA](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/eva)
+- [x] [MixMIM](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mixmim)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 0ebde4ae..5fddca0e 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -155,6 +155,7 @@ mim install -e .
- [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet)
- [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2)
- [x] [EVA](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/eva)
+- [x] [MixMIM](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mixmim)
diff --git a/configs/_base_/models/mixmim/mixmim_base.py b/configs/_base_/models/mixmim/mixmim_base.py
new file mode 100644
index 00000000..ccde3575
--- /dev/null
+++ b/configs/_base_/models/mixmim/mixmim_base.py
@@ -0,0 +1,20 @@
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='MixMIMTransformer', arch='B', drop_rate=0.0, drop_path_rate=0.1),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
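+ # 1024 = 128 (base embed dims) * 2**3, the channel number of the last backbone stage.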
+ in_channels=1024,
+ init_cfg=None,
+ loss=dict(
+ type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
+ cal_acc=False),
+ init_cfg=[
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+ ],
+ train_cfg=dict(augments=[
+ dict(type='Mixup', alpha=0.8),
+ dict(type='CutMix', alpha=1.0)
+ ]))
diff --git a/configs/mixmim/README.md b/configs/mixmim/README.md
new file mode 100644
index 00000000..bcba223d
--- /dev/null
+++ b/configs/mixmim/README.md
@@ -0,0 +1,90 @@
+# MixMIM
+
+> [MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning](https://arxiv.org/abs/2205.13137)
+
+
+
+## Abstract
+
+In this study, we propose Mixed and Masked Image Modeling (MixMIM), a
+simple but efficient MIM method that is applicable to various hierarchical Vision
+Transformers. Existing MIM methods replace a random subset of input tokens with
+a special [MASK] symbol and aim at reconstructing original image tokens from
+the corrupted image. However, we find that using the [MASK] symbol greatly
+slows down the training and causes training-finetuning inconsistency, due to the
+large masking ratio (e.g., 40% in BEiT). In contrast, we replace the masked tokens
+of one image with visible tokens of another image, i.e., creating a mixed image.
+We then conduct dual reconstruction to reconstruct the original two images from
+the mixed input, which significantly improves efficiency. While MixMIM can
+be applied to various architectures, this paper explores a simpler but stronger
+hierarchical Transformer, and scales with MixMIM-B, -L, and -H. Empirical
+results demonstrate that MixMIM can learn high-quality visual representations
+efficiently. Notably, MixMIM-B with 88M parameters achieves 85.1% top-1
+accuracy on ImageNet-1K by pretraining for 600 epochs, setting a new record for
+neural networks with comparable model sizes (e.g., ViT-B) among MIM methods.
+Besides, its transferring performances on the other 6 datasets show MixMIM has a
+better FLOPs / performance tradeoff than previous MIM methods.
+
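+The core mixing idea can be illustrated with a small, hypothetical sketch (the token shapes and the 0.5 masking ratio below are illustrative assumptions, not the actual pretraining code):
+
+```python
+import torch
+
+# Patch tokens of two images (batch, num_tokens, dim); shapes are assumptions.
+tokens_a = torch.rand(1, 196, 768)
+tokens_b = torch.rand(1, 196, 768)
+
+# Binary mask: 1 keeps the token from image A, 0 takes it from image B.
+mask = (torch.rand(1, 196, 1) > 0.5).float()
+
+# Mixed input; the dual reconstruction objective recovers both images from it.
+mixed = mask * tokens_a + (1 - mask) * tokens_b
+```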
+
+
+
+## How to use it?
+
+### Inference
+
+
+
+**Predict image**
+
+```python
+>>> import torch
+>>> import mmcls
+>>> model = mmcls.get_model('mixmim-base_3rdparty_in1k', pretrained=True)
+>>> predict = mmcls.inference_model(model, 'demo/demo.JPEG')
+>>> print(predict['pred_class'])
+sea snake
+>>> print(predict['pred_score'])
+0.865431010723114
+```
+
+**Use the model**
+
+```python
+>>> import torch
+>>> import mmcls
+>>>
+>>> model = mmcls.get_model('mixmim-base_3rdparty_in1k', pretrained=True)
+>>> inputs = torch.rand(1, 3, 224, 224)
+>>> # To get classification scores.
+>>> out = model(inputs)
+>>> print(out.shape)
+torch.Size([1, 1000])
+>>> # To extract features.
+>>> outs = model.extract_feat(inputs)
+>>> print(outs[0].shape)
+torch.Size([1, 1024])
+```
+
+
+
+## Models
+
+| Model | Params(M) | Pretrain Epochs | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+| :-------------------------: | :-------: | :-------------: | :------: | :-------: | :-------: | :-----------------------------------: | :------------------------------------------------------------------------------------: |
+| mixmim-base_3rdparty_in1k\* | 88 | 300 | 16.3 | 84.6 | 97.0 | [config](./mixmim-base_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mixmim/mixmim-base_3rdparty_in1k_20221206-e40e2c8c.pth) |
+
+*Models with * are converted from the [official repo](https://github.com/Sense-X/MixMIM). The config files of these models are only for inference.*
+
+For the MixMIM self-supervised learning algorithm, please refer to the [MMSelfSup page](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim) for more information.
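+
+As a minimal sketch (assuming it is run from the repository root), the checkpoint above can also be loaded through the inference config with `init_model`:
+
+```python
+>>> from mmcls.apis import init_model
+>>> ckpt = 'https://download.openmmlab.com/mmclassification/v0/mixmim/mixmim-base_3rdparty_in1k_20221206-e40e2c8c.pth'
+>>> model = init_model('configs/mixmim/mixmim-base_8xb64_in1k.py', ckpt)
+```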
+
+## Citation
+
+```bibtex
+@article{MixMIM2022,
+ author = {Jihao Liu and Xin Huang and Yu Liu and Hongsheng Li},
+ journal = {arXiv:2205.13137},
+ title = {MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning},
+ year = {2022},
+}
+```
diff --git a/configs/mixmim/metafile.yml b/configs/mixmim/metafile.yml
new file mode 100644
index 00000000..70623c8c
--- /dev/null
+++ b/configs/mixmim/metafile.yml
@@ -0,0 +1,39 @@
+Collections:
+ - Name: MixMIM
+ Metadata:
+ Architecture:
+ - Attention Dropout
+ - Convolution
+ - Dense Connections
+ - Dropout
+ - GELU
+ - Layer Normalization
+ - Multi-Head Attention
+ - Scaled Dot-Product Attention
+ - Tanh Activation
+ Paper:
+ Title: 'MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning'
+ URL: https://arxiv.org/abs/2205.13137
+ README: configs/mixmim/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/mmcls/models/backbones/mixmim.py
+ Version: v1.0.0rc4
+
+Models:
+ - Name: mixmim-base_3rdparty_in1k
+ Metadata:
+ FLOPs: 16352000000
+ Parameters: 88344000
+ Training Data:
+ - ImageNet-1k
+ In Collection: MixMIM
+ Results:
+ - Dataset: ImageNet-1k
+ Task: Image Classification
+ Metrics:
+ Top 1 Accuracy: 84.6
+ Top 5 Accuracy: 97.0
+ Weights: https://download.openmmlab.com/mmclassification/v0/mixmim/mixmim-base_3rdparty_in1k_20221206-e40e2c8c.pth
+ Config: configs/mixmim/mixmim-base_8xb64_in1k.py
+ Converted From:
+ Code: https://github.com/Sense-X/MixMIM
diff --git a/configs/mixmim/mixmim-base_8xb64_in1k.py b/configs/mixmim/mixmim-base_8xb64_in1k.py
new file mode 100644
index 00000000..bb35a037
--- /dev/null
+++ b/configs/mixmim/mixmim-base_8xb64_in1k.py
@@ -0,0 +1,5 @@
+_base_ = [
+ '../_base_/models/mixmim/mixmim_base.py',
+ '../_base_/datasets/imagenet_bs64_swin_224.py',
+ '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
+]
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 458741fb..b583d988 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -16,6 +16,7 @@ from .hornet import HorNet
from .hrnet import HRNet
from .inception_v3 import InceptionV3
from .lenet import LeNet5
+from .mixmim import MixMIMTransformer
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
@@ -102,5 +103,6 @@ __all__ = [
'DaViT',
'BEiT',
'RevVisionTransformer',
+ 'MixMIMTransformer',
'TinyViT',
]
diff --git a/mmcls/models/backbones/mixmim.py b/mmcls/models/backbones/mixmim.py
new file mode 100644
index 00000000..6bed2cf4
--- /dev/null
+++ b/mmcls/models/backbones/mixmim.py
@@ -0,0 +1,494 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Union
+
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.drop import DropPath
+from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging
+from mmengine.model import BaseModule
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+from mmcls.models.backbones.base_backbone import BaseBackbone
+from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
+from mmcls.models.utils.attention import WindowMSA
+from mmcls.models.utils.helpers import to_2tuple
+from mmcls.registry import MODELS
+
+
+class MixMIMWindowAttention(WindowMSA):
+ """MixMIM Window Attention.
+
+ Compared with WindowMSA, we add some modifications
+ in ``forward`` to meet the requirement of MixMIM during
+ pretraining.
+
+ Implements one window attention in MixMIM.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ window_size (list): The height and width of the window.
+ num_heads (int): The number of heads in attention.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ qk_scale (float, optional): Override default qk scale of
+ ``head_dim ** -0.5`` if set. Defaults to None.
+ attn_drop_rate (float): attention drop rate.
+ Defaults to 0.
+ proj_drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ init_cfg=None):
+
+ super().__init__(
+ embed_dims=embed_dims,
+ window_size=window_size,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop_rate,
+ proj_drop=proj_drop_rate,
+ init_cfg=init_cfg)
+
+ def forward(self, x, mask=None):
+
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
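+ # ``mask`` marks which of the two mixed images each token comes from.
+ # Tokens should only attend to tokens from the same image, so the
+ # cross-image attention logits are suppressed with a large negative
+ # bias before the softmax.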
+ mask = mask.reshape(B_, 1, 1, N)
+ mask_new = mask * mask.transpose(
+ 2, 3) + (1 - mask) * (1 - mask).transpose(2, 3)
+ mask_new = 1 - mask_new
+
+ if mask_new.dtype == torch.float16:
+ attn = attn - 65500 * mask_new
+ else:
+ attn = attn - 1e30 * mask_new
+
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MixMIMBlock(TransformerEncoderLayer):
+ """MixMIM Block. Implements one block in MixMIM.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ input_resolution (tuple): Input resolution of this layer.
+ num_heads (int): The number of heads in attention.
+ window_size (int): The height and width of the window.
+ mlp_ratio (int): The MLP ratio in FFN.
+ num_fcs (int): The number of linear layers in a block.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ proj_drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ attn_drop_rate (float): attention drop rate.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate.
+ Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ num_fcs=2,
+ qkv_bias=True,
+ proj_drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+
+ super().__init__(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=int(mlp_ratio * embed_dims),
+ drop_rate=proj_drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ num_fcs=num_fcs,
+ qkv_bias=qkv_bias,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ init_cfg=init_cfg)
+
+ self.embed_dims = embed_dims
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.mlp_ratio = mlp_ratio
+
+ if min(self.input_resolution) <= self.window_size:
+ self.window_size = min(self.input_resolution)
+
+ self.attn = MixMIMWindowAttention(
+ embed_dims=embed_dims,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=proj_drop_rate)
+
+ self.drop_path = DropPath(
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ @staticmethod
+ def window_reverse(windows, H, W, window_size):
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+ @staticmethod
+ def window_partition(x, window_size):
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size,
+ window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ windows = windows.view(-1, window_size, window_size, C)
+ return windows
+
+ def forward(self, x, attn_mask=None):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # partition windows
+ x_windows = self.window_partition(
+ x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
+ C) # nW*B, window_size*window_size, C
+ if attn_mask is not None:
+ attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1
+ attn_mask = attn_mask.view(B, H, W, 1)
+ attn_mask = self.window_partition(attn_mask, self.window_size)
+ attn_mask = attn_mask.view(-1, self.window_size * self.window_size,
+ 1)
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(
+ x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size, C)
+ x = self.window_reverse(attn_windows, H, W,
+ self.window_size) # B H' W' C
+
+ x = x.view(B, H * W, C)
+
+ x = shortcut + self.drop_path(x)
+
+ x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath
+
+ return x
+
+
+class MixMIMLayer(BaseModule):
+ """Implements one MixMIM layer, which may contain several MixMIM blocks.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ input_resolution (tuple): Input resolution of this layer.
+ depth (int): The number of blocks in this layer.
+ num_heads (int): The number of heads in attention.
+ window_size (int): The height and width of the window.
+ mlp_ratio (int): The MLP ratio in FFN.
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ proj_drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ attn_drop_rate (float): attention drop rate.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate.
+ Defaults to 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ downsample (class, optional): Downsample the output of blocks by
+ patch merging. Defaults to None.
+ use_checkpoint (bool): Whether to use checkpointing to reduce GPU
+ memory cost. Defaults to False.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims: int,
+ input_resolution: int,
+ depth: int,
+ num_heads: int,
+ window_size: int,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ proj_drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=[0.],
+ norm_cfg=dict(type='LN'),
+ downsample=None,
+ use_checkpoint=False,
+ init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.embed_dims = embed_dims
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ self.blocks.append(
+ MixMIMBlock(
+ embed_dims=embed_dims,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_drop_rate=proj_drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate[i],
+ norm_cfg=norm_cfg))
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ in_channels=embed_dims,
+ out_channels=2 * embed_dims,
+ norm_cfg=norm_cfg)
+ else:
+ self.downsample = None
+
+ def forward(self, x, attn_mask=None):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask=attn_mask)
+ if self.downsample is not None:
+ x, _ = self.downsample(x, self.input_resolution)
+ return x
+
+ def extra_repr(self) -> str:
+ return (f'dim={self.embed_dims}, '
+ f'input_resolution={self.input_resolution}, '
+ f'depth={self.depth}')
+
+
+@MODELS.register_module()
+class MixMIMTransformer(BaseBackbone):
+ """MixMIM backbone.
+
+ A PyTorch implementation of: `MixMIM: Mixed and Masked Image Modeling
+ for Efficient Visual Representation Learning
+ <https://arxiv.org/abs/2205.13137>`_
+
+ Args:
+ arch (str | dict): MixMIM architecture. If it is a string, choose
+ from 'base', 'large' and 'huge'.
+ If it is a dict, it should have the below keys:
+
+ - **embed_dims** (int): The dimensions of embedding.
+ - **depths** (List[int]): The number of blocks in each stage.
+ - **num_heads** (List[int]): The number of attention heads in each stage.
+
+ Defaults to 'base'.
+ mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
+ img_size (int | tuple): The expected input image shape. Because we
+ support dynamic input shape, just set the argument to the most
+ common input image shape. Defaults to 224.
+ patch_size (int | tuple): The patch size in patch embedding.
+ Defaults to 4.
+ in_channels (int): The number of input channels. Defaults to 3.
+ window_size (list): The height and width of the window in each
+ stage. Defaults to [14, 14, 14, 7].
+ qkv_bias (bool): Whether to add bias for qkv in attention modules.
+ Defaults to True.
+ patch_cfg (dict): Extra config dict for patch embedding.
+ Defaults to an empty dict.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='LN')``.
+ drop_rate (float): Probability of an element to be zeroed.
+ Defaults to 0.
+ drop_path_rate (float): stochastic depth rate. Defaults to 0.
+ attn_drop_rate (float): attention drop rate. Defaults to 0.
+ use_checkpoint (bool): Whether to use checkpointing to reduce GPU
+ memory cost. Defaults to False.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
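+
+ Example:
+ >>> import torch
+ >>> from mmcls.models import MixMIMTransformer
+ >>> model = MixMIMTransformer(arch='base')
+ >>> inputs = torch.rand(1, 3, 224, 224)
+ >>> outputs = model(inputs)
+ >>> print(outputs[0].shape)
+ torch.Size([1, 1024])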
+ """
+ arch_zoo = {
+ **dict.fromkeys(
+ ['b', 'base'], {
+ 'embed_dims': 128,
+ 'depths': [2, 2, 18, 2],
+ 'num_heads': [4, 8, 16, 32]
+ }),
+ **dict.fromkeys(
+ ['l', 'large'], {
+ 'embed_dims': 192,
+ 'depths': [2, 2, 18, 2],
+ 'num_heads': [6, 12, 24, 48]
+ }),
+ **dict.fromkeys(
+ ['h', 'huge'], {
+ 'embed_dims': 352,
+ 'depths': [2, 2, 18, 2],
+ 'num_heads': [11, 22, 44, 88]
+ }),
+ }
+
+ def __init__(
+ self,
+ arch='base',
+ mlp_ratio=4,
+ img_size=224,
+ patch_size=4,
+ in_channels=3,
+ window_size=[14, 14, 14, 7],
+ qkv_bias=True,
+ patch_cfg=dict(),
+ norm_cfg=dict(type='LN'),
+ drop_rate=0.0,
+ drop_path_rate=0.0,
+ attn_drop_rate=0.0,
+ use_checkpoint=False,
+ init_cfg: Optional[dict] = None,
+ ) -> None:
+ super(MixMIMTransformer, self).__init__(init_cfg=init_cfg)
+
+ if isinstance(arch, str):
+ arch = arch.lower()
+ assert arch in set(self.arch_zoo), \
+ f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+ self.arch_settings = self.arch_zoo[arch]
+ else:
+ essential_keys = {'embed_dims', 'depths', 'num_heads'}
+ assert isinstance(arch, dict) and essential_keys <= set(arch), \
+ f'Custom arch needs a dict with keys {essential_keys}'
+ self.arch_settings = arch
+
+ self.embed_dims = self.arch_settings['embed_dims']
+ self.depths = self.arch_settings['depths']
+ self.num_heads = self.arch_settings['num_heads']
+
+ self.encoder_stride = 32
+
+ self.num_layers = len(self.depths)
+ self.qkv_bias = qkv_bias
+ self.drop_rate = drop_rate
+ self.attn_drop_rate = attn_drop_rate
+ self.use_checkpoint = use_checkpoint
+ self.mlp_ratio = mlp_ratio
+ self.window_size = window_size
+
+ _patch_cfg = dict(
+ in_channels=in_channels,
+ input_size=img_size,
+ embed_dims=self.embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=patch_size,
+ norm_cfg=dict(type='LN'),
+ )
+ _patch_cfg.update(patch_cfg)
+ self.patch_embed = PatchEmbed(**_patch_cfg)
+ self.patch_resolution = self.patch_embed.init_out_size
+
+ self.dpr = [
+ x.item()
+ for x in torch.linspace(0, drop_path_rate, sum(self.depths))
+ ]
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ self.layers.append(
+ MixMIMLayer(
+ embed_dims=int(self.embed_dims * 2**i_layer),
+ input_resolution=(self.patch_resolution[0] // (2**i_layer),
+ self.patch_resolution[1] //
+ (2**i_layer)),
+ depth=self.depths[i_layer],
+ num_heads=self.num_heads[i_layer],
+ window_size=self.window_size[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ proj_drop_rate=self.drop_rate,
+ attn_drop_rate=self.attn_drop_rate,
+ drop_path_rate=self.dpr[sum(self.depths[:i_layer]
+ ):sum(self.depths[:i_layer +
+ 1])],
+ norm_cfg=norm_cfg,
+ downsample=PatchMerging if
+ (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=self.use_checkpoint))
+
+ self.num_features = int(self.embed_dims * 2**(self.num_layers - 1))
+ self.drop_after_pos = nn.Dropout(p=self.drop_rate)
+
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, self.num_patches, self.embed_dims),
+ requires_grad=False)
+
+ _, self.norm = build_norm_layer(norm_cfg, self.num_features)
+
+ def forward(self, x: torch.Tensor):
+ x, _ = self.patch_embed(x)
+
+ x = x + self.absolute_pos_embed
+ x = self.drop_after_pos(x)
+
+ for layer in self.layers:
+ x = layer(x, attn_mask=None)
+
+ x = self.norm(x)
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
+ x = torch.flatten(x, 1)
+
+ return (x, )
diff --git a/model-index.yml b/model-index.yml
index d0bbf424..a761ab8a 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -45,3 +45,4 @@ Import:
- configs/beitv2/metafile.yml
- configs/eva/metafile.yml
- configs/revvit/metafile.yml
+ - configs/mixmim/metafile.yml
diff --git a/tests/test_models/test_backbones/test_mixmim.py b/tests/test_models/test_backbones/test_mixmim.py
new file mode 100644
index 00000000..e21d143c
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mixmim.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from unittest import TestCase
+
+import torch
+
+from mmcls.models.backbones import MixMIMTransformer
+
+
+class TestMixMIM(TestCase):
+
+ def setUp(self):
+ self.cfg = dict(arch='b', drop_rate=0.0, drop_path_rate=0.1)
+
+ def test_structure(self):
+
+ # Test the predefined 'b' arch
+ cfg = deepcopy(self.cfg)
+
+ model = MixMIMTransformer(**cfg)
+ self.assertEqual(model.embed_dims, 128)
+ self.assertEqual(sum(model.depths), 24)
+ self.assertIsNotNone(model.absolute_pos_embed)
+
+ num_heads = [4, 8, 16, 32]
+ for i, layer in enumerate(model.layers):
+ self.assertEqual(layer.blocks[0].num_heads, num_heads[i])
+ self.assertEqual(layer.blocks[0].ffn.feedforward_channels,
+ 128 * (2**i) * 4)
+
+ def test_forward(self):
+ imgs = torch.randn(1, 3, 224, 224)
+
+ cfg = deepcopy(self.cfg)
+ model = MixMIMTransformer(**cfg)
+ outs = model(imgs)
+ self.assertIsInstance(outs, tuple)
+ self.assertEqual(len(outs), 1)
+ averaged_token = outs[-1]
+ self.assertEqual(averaged_token.shape, (1, 1024))
diff --git a/tools/model_converters/mixmimx_to_mmcls.py b/tools/model_converters/mixmimx_to_mmcls.py
new file mode 100644
index 00000000..dcf9858b
--- /dev/null
+++ b/tools/model_converters/mixmimx_to_mmcls.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmengine
+import torch
+from mmengine.runner import CheckpointLoader
+
+
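+# The official MixMIM checkpoints use a timm-style PatchMerging, which
+# concatenates the four neighboring patches in a different channel order than
+# mmcv's unfold-based PatchMerging, so the reduction/norm weights have to be
+# re-ordered during conversion.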
+def correct_unfold_reduction_order(x: torch.Tensor):
+ out_channel, in_channel = x.shape
+ x = x.reshape(out_channel, 4, in_channel // 4)
+ x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel)
+ return x
+
+
+def correct_unfold_norm_order(x):
+ in_channel = x.shape[0]
+ x = x.reshape(4, in_channel // 4)
+ x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
+ return x
+
+
+def convert_mixmim(ckpt):
+
+ new_ckpt = OrderedDict()
+
+ for k, v in list(ckpt.items()):
+ new_v = v
+
+ if k.startswith('patch_embed'):
+ new_k = k.replace('proj', 'projection')
+
+ elif k.startswith('layers'):
+ if 'norm1' in k:
+ new_k = k.replace('norm1', 'ln1')
+ elif 'norm2' in k:
+ new_k = k.replace('norm2', 'ln2')
+ elif 'mlp.fc1' in k:
+ new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
+ elif 'mlp.fc2' in k:
+ new_k = k.replace('mlp.fc2', 'ffn.layers.1')
+ else:
+ new_k = k
+
+ elif k.startswith('norm') or k.startswith('absolute_pos_embed'):
+ new_k = k
+
+ elif k.startswith('head'):
+ new_k = k.replace('head.', 'head.fc.')
+
+ else:
+ raise ValueError
+
+ if not new_k.startswith('head'):
+ new_k = 'backbone.' + new_k
+
+ if 'downsample' in new_k:
+ print('Convert {} in PatchMerging from timm to mmcv format!'.format(
+ new_k))
+
+ if 'reduction' in new_k:
+ new_v = correct_unfold_reduction_order(new_v)
+ elif 'norm' in new_k:
+ new_v = correct_unfold_norm_order(new_v)
+
+ new_ckpt[new_k] = new_v
+
+ return new_ckpt
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Convert keys in pretrained MixMIM models to mmcls style.')
+ parser.add_argument('src', help='src model path or url')
+ # The dst path must be a full path of the new checkpoint.
+ parser.add_argument('dst', help='save path')
+ args = parser.parse_args()
+
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+ if 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+
+ weight = convert_mixmim(state_dict)
+ mmengine.mkdir_or_exist(osp.dirname(args.dst))
+ torch.save(weight, args.dst)
+
+ print('Done!!')
+
+
+if __name__ == '__main__':
+ main()