diff --git a/README.md b/README.md index 3bdb53c7..95f3c9f5 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea - [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet) - [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2) - [x] [EVA](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/eva) +- [x] [MixMIM](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mixmim) diff --git a/README_zh-CN.md b/README_zh-CN.md index 0ebde4ae..5fddca0e 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -155,6 +155,7 @@ mim install -e . - [x] [RepLKNet](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/replknet) - [x] [BEiT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beit) / [BEiT v2](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/beitv2) - [x] [EVA](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/eva) +- [x] [MixMIM](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mixmim) diff --git a/configs/_base_/models/mixmim/mixmim_base.py b/configs/_base_/models/mixmim/mixmim_base.py new file mode 100644 index 00000000..ccde3575 --- /dev/null +++ b/configs/_base_/models/mixmim/mixmim_base.py @@ -0,0 +1,20 @@ +model = dict( + type='ImageClassifier', + backbone=dict( + type='MixMIMTransformer', arch='B', drop_rate=0.0, drop_path_rate=0.1), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=1024, + init_cfg=None, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0) + ])) diff --git a/configs/mixmim/README.md b/configs/mixmim/README.md new file mode 100644 index 00000000..bcba223d --- /dev/null +++ b/configs/mixmim/README.md @@ -0,0 +1,90 @@ +# MixMIM + +> [MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning](https://arxiv.org/abs/2205.13137) + + + +## Abstract + +In this study, we propose Mixed and Masked Image Modeling (MixMIM), a +simple but efficient MIM method that is applicable to various hierarchical Vision +Transformers. Existing MIM methods replace a random subset of input tokens with +a special [MASK] symbol and aim at reconstructing original image tokens from +the corrupted image. However, we find that using the [MASK] symbol greatly +slows down the training and causes training-finetuning inconsistency, due to the +large masking ratio (e.g., 40% in BEiT). In contrast, we replace the masked tokens +of one image with visible tokens of another image, i.e., creating a mixed image. +We then conduct dual reconstruction to reconstruct the original two images from +the mixed input, which significantly improves efficiency. While MixMIM can +be applied to various architectures, this paper explores a simpler but stronger +hierarchical Transformer, and scales with MixMIM-B, -L, and -H. Empirical +results demonstrate that MixMIM can learn high-quality visual representations +efficiently. 
Notably, MixMIM-B with 88M parameters achieves 85.1% top-1
+accuracy on ImageNet-1K by pretraining for 600 epochs, setting a new record for
+neural networks with comparable model sizes (e.g., ViT-B) among MIM methods.
+Besides, its transfer performance on the other 6 datasets shows that MixMIM has
+a better FLOPs / performance tradeoff than previous MIM methods.
+
+
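+The mixing strategy described above can be sketched in a few lines of PyTorch.
+This is purely an illustrative sketch (the real pretraining code lives in
+MMSelfSup); the tensor shapes are arbitrary.
+
+```python
+import torch
+
+# Patch tokens of two different images: (B, N, C)
+tokens_a = torch.rand(2, 196, 128)
+tokens_b = torch.rand(2, 196, 128)
+
+# Random binary mask over the N token positions (1 = keep the token of image A)
+mask = (torch.rand(2, 196, 1) > 0.5).float()
+
+# Mixed input: the masked positions of image A are filled with the visible
+# tokens of image B instead of a special [MASK] symbol
+mixed_tokens = mask * tokens_a + (1 - mask) * tokens_b
+
+# Dual reconstruction: the model is trained to recover both original token
+# sequences from the single mixed input
+reconstruction_targets = (tokens_a, tokens_b)
+```
+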
+ +
+ +## How to use it? + +### Inference + + + +**Predict image** + +```python +>>> import torch +>>> import mmcls +>>> model = mmcls.get_model('mixmim-base_3rdparty_in1k', pretrained=True) +>>> predict = mmcls.inference_model(model, 'demo/demo.JPEG') +>>> print(predict['pred_class']) +sea snake +>>> print(predict['pred_score']) +0.865431010723114 +``` + +**Use the model** + +```python +>>> import torch +>>> import mmcls +>>> +>>> model = mmcls.get_model('mixmim-base_3rdparty_in1k', pretrained=True) +>>> inputs = torch.rand(1, 3, 224, 224) +>>> # To get classification scores. +>>> out = model(inputs) +>>> print(out.shape) +torch.Size([1, 1000]) +>>> # To extract features. +>>> outs = model.extract_feat(inputs) +>>> print(outs[0].shape) +torch.Size([1, 1024]) +``` + + + +## Models + +| Model | Params(M) | Pretrain Epochs | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | +| :-------------------------: | :-------: | :-------------: | :------: | :-------: | :-------: | :-----------------------------------: | :------------------------------------------------------------------------------------: | +| mixmim-base_3rdparty_in1k\* | 88 | 300 | 16.3 | 84.6 | 97.0 | [config](./mixmim-base_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mixmim/mixmim-base_3rdparty_in1k_20221206-e40e2c8c.pth) | + +*Models with * are converted from the [official repo](https://github.com/Sense-X/MixMIM). The config files of these models are only for inference.* + +For MixMIM self-supervised learning algorithm, welcome to [MMSelfSup page](https://github.com/open-mmlab/mmselfsup/tree/dev-1.x/configs/selfsup/mixmim) to get more information. + +## Citation + +```bibtex +@article{MixMIM2022, + author = {Jihao Liu, Xin Huang, Yu Liu, Hongsheng Li}, + journal = {arXiv:2205.13137}, + title = {MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning}, + year = {2022}, +} +``` diff --git a/configs/mixmim/metafile.yml b/configs/mixmim/metafile.yml new file mode 100644 index 00000000..70623c8c --- /dev/null +++ b/configs/mixmim/metafile.yml @@ -0,0 +1,39 @@ +Collections: + - Name: MixMIM + Metadata: + Architecture: + - Attention Dropout + - Convolution + - Dense Connections + - Dropout + - GELU + - Layer Normalization + - Multi-Head Attention + - Scaled Dot-Product Attention + - Tanh Activation + Paper: + Title: 'MixMIM: Mixed and Masked Image Modeling for Efficient Visual Representation Learning' + URL: https://arxiv.org/abs/2205.13137 + README: configs/mixmim/README.md + Code: + URL: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/mmcls/models/backbones/mixmim.py + Version: v1.0.0rc4 + +Models: + - Name: mixmim-base_3rdparty_in1k + Metadata: + FLOPs: 16352000000 + Parameters: 88344000 + Training Data: + - ImageNet-1k + In Collection: MixMIM + Results: + - Dataset: ImageNet-1k + Task: Image Classification + Metrics: + Top 1 Accuracy: 84.6 + Top 5 Accuracy: 97.0 + Weights: https://download.openmmlab.com/mmclassification/v0/mixmim/mixmim-base_3rdparty_in1k_20221206-e40e2c8c.pth + Config: configs/mixmim/mixmim-base_8xb64_in1k.py + Converted From: + Code: https://github.com/Sense-X/MixMIM diff --git a/configs/mixmim/mixmim-base_8xb64_in1k.py b/configs/mixmim/mixmim-base_8xb64_in1k.py new file mode 100644 index 00000000..bb35a037 --- /dev/null +++ b/configs/mixmim/mixmim-base_8xb64_in1k.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/mixmim/mixmim_base.py', + '../_base_/datasets/imagenet_bs64_swin_224.py', + '../_base_/schedules/imagenet_bs256.py', 
'../_base_/default_runtime.py' +] diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index 458741fb..b583d988 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -16,6 +16,7 @@ from .hornet import HorNet from .hrnet import HRNet from .inception_v3 import InceptionV3 from .lenet import LeNet5 +from .mixmim import MixMIMTransformer from .mlp_mixer import MlpMixer from .mobilenet_v2 import MobileNetV2 from .mobilenet_v3 import MobileNetV3 @@ -102,5 +103,6 @@ __all__ = [ 'DaViT', 'BEiT', 'RevVisionTransformer', + 'MixMIMTransformer', 'TinyViT', ] diff --git a/mmcls/models/backbones/mixmim.py b/mmcls/models/backbones/mixmim.py new file mode 100644 index 00000000..6bed2cf4 --- /dev/null +++ b/mmcls/models/backbones/mixmim.py @@ -0,0 +1,494 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.drop import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging +from mmengine.model import BaseModule +from torch import nn +from torch.utils.checkpoint import checkpoint + +from mmcls.models.backbones.base_backbone import BaseBackbone +from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer +from mmcls.models.utils.attention import WindowMSA +from mmcls.models.utils.helpers import to_2tuple +from mmcls.registry import MODELS + + +class MixMIMWindowAttention(WindowMSA): + """MixMIM Window Attention. + + Compared with WindowMSA, we add some modifications + in ``forward`` to meet the requirement of MixMIM during + pretraining. + + Implements one windown attention in MixMIM. + Args: + embed_dims (int): The feature dimension. + window_size (list): The height and width of the window. + num_heads (int): The number of head in attention. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + qk_scale (float, optional): Override default qk scale of + ``head_dim ** -0.5`` if set. Defaults to None. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
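+
+    Example (an illustrative usage sketch added for clarity; the shapes and
+    hyper-parameters below are arbitrary, not taken from a released config)::
+
+        >>> import torch
+        >>> attn = MixMIMWindowAttention(
+        ...     embed_dims=128, window_size=(14, 14), num_heads=4)
+        >>> windows = torch.rand(8, 14 * 14, 128)  # (num_windows*B, N, C)
+        >>> attn(windows).shape  # mask=None falls back to plain window MSA
+        torch.Size([8, 196, 128])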
+ """ + + def __init__(self, + embed_dims, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop_rate=0., + proj_drop_rate=0., + init_cfg=None): + + super().__init__( + embed_dims=embed_dims, + window_size=window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop_rate, + proj_drop=proj_drop_rate, + init_cfg=init_cfg) + + def forward(self, x, mask=None): + + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + mask = mask.reshape(B_, 1, 1, N) + mask_new = mask * mask.transpose( + 2, 3) + (1 - mask) * (1 - mask).transpose(2, 3) + mask_new = 1 - mask_new + + if mask_new.dtype == torch.float16: + attn = attn - 65500 * mask_new + else: + attn = attn - 1e30 * mask_new + + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MixMIMBlock(TransformerEncoderLayer): + """MixMIM Block. Implements one block in MixMIM. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + num_fcs (int): The number of linear layers in a block. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. + Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. 
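+
+    Example (an illustrative usage sketch added for clarity; the resolution,
+    width and head count below are arbitrary, not taken from a released
+    config)::
+
+        >>> import torch
+        >>> block = MixMIMBlock(
+        ...     embed_dims=128, input_resolution=(14, 14), num_heads=4,
+        ...     window_size=7)
+        >>> x = torch.rand(2, 14 * 14, 128)
+        >>> block(x).shape
+        torch.Size([2, 196, 128])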
+ """ + + def __init__(self, + embed_dims, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4., + num_fcs=2, + qkv_bias=True, + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='LN'), + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + + super().__init__( + embed_dims=embed_dims, + num_heads=num_heads, + feedforward_channels=int(mlp_ratio * embed_dims), + drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + num_fcs=num_fcs, + qkv_bias=qkv_bias, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + init_cfg=init_cfg) + + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + if min(self.input_resolution) <= self.window_size: + self.window_size = min(self.input_resolution) + + self.attn = MixMIMWindowAttention( + embed_dims=embed_dims, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop_rate=attn_drop_rate, + proj_drop_rate=proj_drop_rate) + + self.drop_path = DropPath( + drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + @staticmethod + def window_reverse(windows, H, W, window_size): + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + @staticmethod + def window_partition(x, window_size): + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, + window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() + windows = windows.view(-1, window_size, window_size, C) + return windows + + def forward(self, x, attn_mask=None): + H, W = self.input_resolution + B, L, C = x.shape + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # partition windows + x_windows = self.window_partition( + x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + if attn_mask is not None: + attn_mask = attn_mask.repeat(B, 1, 1) # B, N, 1 + attn_mask = attn_mask.view(B, H, W, 1) + attn_mask = self.window_partition(attn_mask, self.window_size) + attn_mask = attn_mask.view(-1, self.window_size * self.window_size, + 1) + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + x = self.window_reverse(attn_windows, H, W, + self.window_size) # B H' W' C + + x = x.view(B, H * W, C) + + x = shortcut + self.drop_path(x) + + x = self.ffn(self.norm2(x), identity=x) # ffn contains DropPath + + return x + + +class MixMIMLayer(BaseModule): + """Implements one MixMIM layer, which may contains several MixMIM blocks. + + Args: + embed_dims (int): The feature dimension. + input_resolution (tuple): Input resolution of this layer. + depth (int): The number of blocks in this layer. + num_heads (int): The number of head in attention, + window_size (list): The height and width of the window. + mlp_ratio (int): The MLP ration in FFN. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + proj_drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + attn_drop_rate (float): attention drop rate. + Defaults to 0. 
+ drop_path_rate (float): stochastic depth rate. + Defaults to 0. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + downsample (class, optional): Downsample the output of blocks b + y patch merging.Defaults to None. + use_checkpoint (bool): Whether use the checkpoint to + reduce GPU memory cost. + init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + + def __init__(self, + embed_dims: int, + input_resolution: int, + depth: int, + num_heads: int, + window_size: int, + mlp_ratio=4., + qkv_bias=True, + proj_drop_rate=0., + attn_drop_rate=0., + drop_path_rate=[0.], + norm_cfg=dict(type='LN'), + downsample=None, + use_checkpoint=False, + init_cfg: Optional[Union[List[dict], dict]] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.embed_dims = embed_dims + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList() + for i in range(depth): + self.blocks.append( + MixMIMBlock( + embed_dims=embed_dims, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop_rate=proj_drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate[i], + norm_cfg=norm_cfg)) + # patch merging layer + if downsample is not None: + self.downsample = downsample( + in_channels=embed_dims, + out_channels=2 * embed_dims, + norm_cfg=norm_cfg) + else: + self.downsample = None + + def forward(self, x, attn_mask=None): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask=attn_mask) + if self.downsample is not None: + x, _ = self.downsample(x, self.input_resolution) + return x + + def extra_repr(self) -> str: + return f'dim={self.embed_dims}, \ + input_resolution={self.input_resolution}, depth={self.depth}' + + +@MODELS.register_module() +class MixMIMTransformer(BaseBackbone): + """MixMIM backbone. + + A PyTorch implement of : ` MixMIM: Mixed and Masked Image + Modeling for Efficient Visual Representation Learning + `_ + + Args: + arch (str | dict): MixMIM architecture. If use string, + choose from 'base','large' and 'huge'. + If use dict, it should have below keys: + + - **embed_dims** (int): The dimensions of embedding. + - **depths** (int): The number of transformer encoder layers. + - **num_heads** (int): The number of heads in attention modules. + + Defaults to 'base'. + mlp_ratio (int): The mlp ratio in FFN. Defaults to 4. + img_size (int | tuple): The expected input image shape. Because we + support dynamic input shape, just set the argument to mlp_ratio + the most common input image shape. Defaults to 224. + patch_size (int | tuple): The patch size in patch embedding. + Defaults to 16. + in_channels (int): The num of input channels. Defaults to 3. + window_size (list): The height and width of the window. + qkv_bias (bool): Whether to add bias for qkv in attention modules. + Defaults to True. + patch_cfg (dict): Extra config dict for patch embedding. + Defaults to an empty dict. + norm_cfg (dict): Config dict for normalization layer. + Defaults to ``dict(type='LN')``. + drop_rate (float): Probability of an element to be zeroed. + Defaults to 0. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + attn_drop_rate (float): attention drop rate. Defaults to 0. + use_checkpoint (bool): Whether use the checkpoint to + reduce GPU memory cost. 
+ init_cfg (dict, optional): Initialization config dict. + Defaults to None. + """ + arch_zoo = { + **dict.fromkeys( + ['b', 'base'], { + 'embed_dims': 128, + 'depths': [2, 2, 18, 2], + 'num_heads': [4, 8, 16, 32] + }), + **dict.fromkeys( + ['l', 'large'], { + 'embed_dims': 192, + 'depths': [2, 2, 18, 2], + 'num_heads': [6, 12, 24, 48] + }), + **dict.fromkeys( + ['h', 'huge'], { + 'embed_dims': 352, + 'depths': [2, 2, 18, 2], + 'num_heads': [11, 22, 44, 88] + }), + } + + def __init__( + self, + arch='base', + mlp_ratio=4, + img_size=224, + patch_size=4, + in_channels=3, + window_size=[14, 14, 14, 7], + qkv_bias=True, + patch_cfg=dict(), + norm_cfg=dict(type='LN'), + drop_rate=0.0, + drop_path_rate=0.0, + attn_drop_rate=0.0, + use_checkpoint=False, + init_cfg: Optional[dict] = None, + ) -> None: + super(MixMIMTransformer, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + arch = arch.lower() + assert arch in set(self.arch_zoo), \ + f'Arch {arch} is not in default archs {set(self.arch_zoo)}' + self.arch_settings = self.arch_zoo[arch] + else: + essential_keys = { + 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels' + } + assert isinstance(arch, dict) and essential_keys <= set(arch), \ + f'Custom arch needs a dict with keys {essential_keys}' + self.arch_settings = arch + + self.embed_dims = self.arch_settings['embed_dims'] + self.depths = self.arch_settings['depths'] + self.num_heads = self.arch_settings['num_heads'] + + self.encoder_stride = 32 + + self.num_layers = len(self.depths) + self.qkv_bias = qkv_bias + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.use_checkpoint = use_checkpoint + self.mlp_ratio = mlp_ratio + self.window_size = window_size + + _patch_cfg = dict( + in_channels=in_channels, + input_size=img_size, + embed_dims=self.embed_dims, + conv_type='Conv2d', + kernel_size=patch_size, + stride=patch_size, + norm_cfg=dict(type='LN'), + ) + _patch_cfg.update(patch_cfg) + self.patch_embed = PatchEmbed(**_patch_cfg) + self.patch_resolution = self.patch_embed.init_out_size + + self.dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + self.layers.append( + MixMIMLayer( + embed_dims=int(self.embed_dims * 2**i_layer), + input_resolution=(self.patch_resolution[0] // (2**i_layer), + self.patch_resolution[1] // + (2**i_layer)), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + proj_drop_rate=self.drop_rate, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=self.dpr[sum(self.depths[:i_layer] + ):sum(self.depths[:i_layer + + 1])], + norm_cfg=norm_cfg, + downsample=PatchMerging if + (i_layer < self.num_layers - 1) else None, + use_checkpoint=self.use_checkpoint)) + + self.num_features = int(self.embed_dims * 2**(self.num_layers - 1)) + self.drop_after_pos = nn.Dropout(p=self.drop_rate) + + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.num_patches = self.patch_resolution[0] * self.patch_resolution[1] + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, self.embed_dims), + requires_grad=False) + + _, self.norm = build_norm_layer(norm_cfg, self.num_features) + + def forward(self, x: torch.Tensor): + x, _ = self.patch_embed(x) + + x = x + self.absolute_pos_embed + x = self.drop_after_pos(x) + + for layer in self.layers: + x = layer(x, attn_mask=None) + + x = self.norm(x) + x = 
self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + + return (x, ) diff --git a/model-index.yml b/model-index.yml index d0bbf424..a761ab8a 100644 --- a/model-index.yml +++ b/model-index.yml @@ -45,3 +45,4 @@ Import: - configs/beitv2/metafile.yml - configs/eva/metafile.yml - configs/revvit/metafile.yml + - configs/mixmim/metafile.yml diff --git a/tests/test_models/test_backbones/test_mixmim.py b/tests/test_models/test_backbones/test_mixmim.py new file mode 100644 index 00000000..e21d143c --- /dev/null +++ b/tests/test_models/test_backbones/test_mixmim.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from unittest import TestCase + +import torch + +from mmcls.models.backbones import MixMIMTransformer + + +class TestMixMIM(TestCase): + + def setUp(self): + self.cfg = dict(arch='b', drop_rate=0.0, drop_path_rate=0.1) + + def test_structure(self): + + # Test custom arch + cfg = deepcopy(self.cfg) + + model = MixMIMTransformer(**cfg) + self.assertEqual(model.embed_dims, 128) + self.assertEqual(sum(model.depths), 24) + self.assertIsNotNone(model.absolute_pos_embed) + + num_heads = [4, 8, 16, 32] + for i, layer in enumerate(model.layers): + self.assertEqual(layer.blocks[0].num_heads, num_heads[i]) + self.assertEqual(layer.blocks[0].ffn.feedforward_channels, + 128 * (2**i) * 4) + + def test_forward(self): + imgs = torch.randn(1, 3, 224, 224) + + cfg = deepcopy(self.cfg) + model = MixMIMTransformer(**cfg) + outs = model(imgs) + self.assertIsInstance(outs, tuple) + self.assertEqual(len(outs), 1) + averaged_token = outs[-1] + self.assertEqual(averaged_token.shape, (1, 1024)) diff --git a/tools/model_converters/mixmimx_to_mmcls.py b/tools/model_converters/mixmimx_to_mmcls.py new file mode 100644 index 00000000..dcf9858b --- /dev/null +++ b/tools/model_converters/mixmimx_to_mmcls.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmengine +import torch +from mmengine.runner import CheckpointLoader + + +def correct_unfold_reduction_order(x: torch.Tensor): + out_channel, in_channel = x.shape + x = x.reshape(out_channel, 4, in_channel // 4) + x = x[:, [0, 2, 1, 3], :].transpose(1, 2).reshape(out_channel, in_channel) + return x + + +def correct_unfold_norm_order(x): + in_channel = x.shape[0] + x = x.reshape(4, in_channel // 4) + x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) + return x + + +def convert_mixmim(ckpt): + + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + + if k.startswith('patch_embed'): + new_k = k.replace('proj', 'projection') + + elif k.startswith('layers'): + if 'norm1' in k: + new_k = k.replace('norm1', 'ln1') + elif 'norm2' in k: + new_k = k.replace('norm2', 'ln2') + elif 'mlp.fc1' in k: + new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') + elif 'mlp.fc2' in k: + new_k = k.replace('mlp.fc2', 'ffn.layers.1') + else: + new_k = k + + elif k.startswith('norm') or k.startswith('absolute_pos_embed'): + new_k = k + + elif k.startswith('head'): + new_k = k.replace('head.', 'head.fc.') + + else: + raise ValueError + + # print(new_k) + if not new_k.startswith('head'): + new_k = 'backbone.' 
+ new_k
+
+        if 'downsample' in new_k:
+            print('Convert {} in PatchMerging from timm to mmcv format!'.format(
+                new_k))
+
+            if 'reduction' in new_k:
+                new_v = correct_unfold_reduction_order(new_v)
+            elif 'norm' in new_k:
+                new_v = correct_unfold_norm_order(new_v)
+
+        new_ckpt[new_k] = new_v
+
+    return new_ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys in pretrained MixMIM models to mmcls style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+    if 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+
+    weight = convert_mixmim(state_dict)
+    mmengine.mkdir_or_exist(osp.dirname(args.dst))
+    torch.save(weight, args.dst)
+
+    print('Done!!')
+
+
+if __name__ == '__main__':
+    main()
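+
+# A hypothetical invocation (the paths below are placeholders, not pointers to
+# released files):
+#
+#   python tools/model_converters/mixmimx_to_mmcls.py \
+#       path/to/official_mixmim_base.pth path/to/mixmim-base_mmcls.pth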