diff --git a/configs/_base_/models/upernet_mae.py b/configs/_base_/models/upernet_mae.py
new file mode 100644
index 000000000..1e0da7082
--- /dev/null
+++ b/configs/_base_/models/upernet_mae.py
@@ -0,0 +1,49 @@
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='MAE',
+ img_size=(640, 640),
+ patch_size=16,
+ in_channels=3,
+ embed_dims=768,
+ num_layers=12,
+ num_heads=12,
+ mlp_ratio=4,
+ out_indices=(3, 5, 7, 11),
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ act_cfg=dict(type='GELU'),
+ norm_eval=False,
+ init_values=0.1),
+ neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[384, 384, 384, 384],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=512,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=384,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/configs/mae/README.md b/configs/mae/README.md
new file mode 100644
index 000000000..f42ff0a71
--- /dev/null
+++ b/configs/mae/README.md
@@ -0,0 +1,81 @@
+# MAE
+
+[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. Our MAE approach is simple: we mask random patches of the input image and reconstruct the missing pixels. It is based on two core designs. First, we develop an asymmetric encoder-decoder architecture, with an encoder that operates only on the visible subset of patches (without mask tokens), along with a lightweight decoder that reconstructs the original image from the latent representation and mask tokens. Second, we find that masking a high proportion of the input image, e.g., 75%, yields a nontrivial and meaningful self-supervisory task. Coupling these two designs enables us to train large models efficiently and effectively: we accelerate training (by 3x or more) and improve accuracy. Our scalable approach allows for learning high-capacity models that generalize well: e.g., a vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pre-training and shows promising scaling behavior.
+
+
+
+

+
+
+## Citation
+
+```bibtex
+@article{he2021masked,
+ title={Masked autoencoders are scalable vision learners},
+ author={He, Kaiming and Chen, Xinlei and Xie, Saining and Li, Yanghao and Doll{\'a}r, Piotr and Girshick, Ross},
+ journal={arXiv preprint arXiv:2111.06377},
+ year={2021}
+}
+```
+
+## Usage
+
+To use other repositories' pre-trained models, it is necessary to convert keys.
+
+We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the key of MAE model from [the official repo](https://github.com/facebookresearch/mae) to MMSegmentation style.
+
+```shell
+python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+
+E.g.
+
+```shell
+python tools/model_converters/beit2mmseg.py https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth pretrain/mae_pretrain_vit_base_mmcls.pth
+```
+
+This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
+
+In our default setting, pretrained models could be defined below:
+
+ | pretrained models | original models |
+ | ------ | -------- |
+ |mae_pretrain_vit_base_mmcls.pth | ['mae_pretrain_vit_base'](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth) |
+
+Verify the single-scale results of the model:
+
+```shell
+sh tools/dist_test.sh \
+configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py \
+upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
+```
+
+Since relative position embedding requires the input length and width to be equal, the sliding window is adopted for multi-scale inference. So we set min_size=512, that is, the shortest edge is 512. So the multi-scale inference of config is performed separately, instead of '--aug-test'. For multi-scale inference:
+
+```shell
+sh tools/dist_test.sh \
+configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py \
+upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
+```
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | pretrain | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- | ------------: | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| UperNet | ViT-B | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 9.96 | 7.14 | 48.13 | 48.70 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752.log.json) |
diff --git a/configs/mae/mae.yml b/configs/mae/mae.yml
new file mode 100644
index 000000000..5a869344e
--- /dev/null
+++ b/configs/mae/mae.yml
@@ -0,0 +1,23 @@
+Models:
+- Name: upernet_mae-base_fp16_8x2_512x512_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: ViT-B
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 140.06
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (512,512)
+ Training Memory (GB): 9.96
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 48.13
+ mIoU(ms+flip): 48.7
+ Config: configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth
diff --git a/configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py b/configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py
new file mode 100644
index 000000000..85b3be303
--- /dev/null
+++ b/configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py
@@ -0,0 +1,24 @@
+_base_ = './upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py'
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2048, 512),
+ img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True, min_size=512),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline),
+ samples_per_gpu=2)
diff --git a/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py b/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
new file mode 100644
index 000000000..cb236cc04
--- /dev/null
+++ b/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
@@ -0,0 +1,48 @@
+_base_ = [
+ '../_base_/models/upernet_mae.py', '../_base_/datasets/ade20k.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+
+model = dict(
+ pretrained='./pretrain/mae_pretrain_vit_base_mmcls.pth',
+ backbone=dict(
+ type='MAE',
+ img_size=(512, 512),
+ patch_size=16,
+ embed_dims=768,
+ num_layers=12,
+ num_heads=12,
+ mlp_ratio=4,
+ init_values=1.0,
+ drop_path_rate=0.1,
+ out_indices=[3, 5, 7, 11]),
+ neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5]),
+ decode_head=dict(
+ in_channels=[768, 768, 768, 768], num_classes=150, channels=768),
+ auxiliary_head=dict(in_channels=768, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)))
+
+optimizer = dict(
+ _delete_=True,
+ type='AdamW',
+ lr=1e-4,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ constructor='LayerDecayOptimizerConstructor',
+ paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# mixed precision
+fp16 = dict(loss_scale='dynamic')
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 1ede4874d..bda42bb69 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -7,6 +7,7 @@ from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .icnet import ICNet
+from .mae import MAE
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
@@ -25,5 +26,5 @@ __all__ = [
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
- 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT'
+ 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
]
diff --git a/mmseg/models/backbones/mae.py b/mmseg/models/backbones/mae.py
new file mode 100644
index 000000000..d3e8754bd
--- /dev/null
+++ b/mmseg/models/backbones/mae.py
@@ -0,0 +1,261 @@
+# Copyright (c) OpenMMLab. All rights reserved.import math
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+ trunc_normal_)
+from mmcv.runner import ModuleList, _load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
+
+
+class MAEAttention(BEiTAttention):
+ """Multi-head self-attention with relative position bias used in MAE.
+
+ This module is different from ``BEiTAttention`` by initializing the
+ relative bias table with zeros.
+ """
+
+ def init_weights(self):
+ """Initialize relative position bias with zeros."""
+
+ # As MAE initializes relative position bias as zeros and this class
+ # inherited from BEiT which initializes relative position bias
+ # with `trunc_normal`, `init_weights` here does
+ # nothing and just passes directly
+
+ pass
+
+
+class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
+ """Implements one encoder layer in Vision Transformer.
+
+ This module is different from ``BEiTTransformerEncoderLayer`` by replacing
+ ``BEiTAttention`` with ``MAEAttention``.
+ """
+
+ def build_attn(self, attn_cfg):
+ self.attn = MAEAttention(**attn_cfg)
+
+
+@BACKBONES.register_module()
+class MAE(BEiT):
+ """VisionTransformer with support for patch.
+
+ Args:
+ img_size (int | tuple): Input image size. Default: 224.
+ patch_size (int): The patch size. Default: 16.
+ in_channels (int): Number of input channels. Default: 3.
+ embed_dims (int): embedding dimension. Default: 768.
+ num_layers (int): depth of transformer. Default: 12.
+ num_heads (int): number of attention heads. Default: 12.
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ out_indices (list | tuple | int): Output from which stages.
+ Default: -1.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default 0.0
+ drop_path_rate (float): stochastic depth rate. Default 0.0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN')
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ patch_norm (bool): Whether to add a norm in PatchEmbed Block.
+ Default: False.
+ final_norm (bool): Whether to add a additional layer to normalize
+ final feature map. Default: False.
+ num_fcs (int): The number of fully-connected layers for FFNs.
+ Default: 2.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ pretrained (str, optional): model pretrained path. Default: None.
+ init_values (float): Initialize the values of Attention and FFN
+ with learnable scaling. Defaults to 0.1.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dims=768,
+ num_layers=12,
+ num_heads=12,
+ mlp_ratio=4,
+ out_indices=-1,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='GELU'),
+ patch_norm=False,
+ final_norm=False,
+ num_fcs=2,
+ norm_eval=False,
+ pretrained=None,
+ init_values=0.1,
+ init_cfg=None):
+ super(MAE, self).__init__(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dims=embed_dims,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ out_indices=out_indices,
+ qv_bias=False,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=drop_path_rate,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ patch_norm=patch_norm,
+ final_norm=final_norm,
+ num_fcs=num_fcs,
+ norm_eval=norm_eval,
+ pretrained=pretrained,
+ init_values=init_values,
+ init_cfg=init_cfg)
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
+
+ self.num_patches = self.patch_shape[0] * self.patch_shape[1]
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, self.num_patches + 1, embed_dims))
+
+ def _build_layers(self):
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
+ ]
+ self.layers = ModuleList()
+ for i in range(self.num_layers):
+ self.layers.append(
+ MAETransformerEncoderLayer(
+ embed_dims=self.embed_dims,
+ num_heads=self.num_heads,
+ feedforward_channels=self.mlp_ratio * self.embed_dims,
+ attn_drop_rate=self.attn_drop_rate,
+ drop_path_rate=dpr[i],
+ num_fcs=self.num_fcs,
+ bias=True,
+ act_cfg=self.act_cfg,
+ norm_cfg=self.norm_cfg,
+ window_size=self.patch_shape,
+ init_values=self.init_values))
+
+ def fix_init_weight(self):
+ """Rescale the initialization according to layer id.
+
+ This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
+ Copyright (c) Microsoft Corporation
+ Licensed under the MIT License
+ """
+
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.layers):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
+
+ def init_weights(self):
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ self.apply(_init_weights)
+ self.fix_init_weight()
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg.get('type') == 'Pretrained'):
+ logger = get_root_logger()
+ checkpoint = _load_checkpoint(
+ self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
+ state_dict = self.resize_rel_pos_embed(checkpoint)
+ state_dict = self.resize_abs_pos_embed(state_dict)
+ self.load_state_dict(state_dict, False)
+ elif self.init_cfg is not None:
+ super(MAE, self).init_weights()
+ else:
+ # We only implement the 'jax_impl' initialization implemented at
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
+ # Copyright 2019 Ross Wightman
+ # Licensed under the Apache License, Version 2.0 (the "License")
+ trunc_normal_(self.cls_token, std=.02)
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ if 'ffn' in n:
+ nn.init.normal_(m.bias, mean=0., std=1e-6)
+ else:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ kaiming_init(m, mode='fan_in', bias=0.)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ constant_init(m, val=1.0, bias=0.)
+
+ def resize_abs_pos_embed(self, state_dict):
+ if 'pos_embed' in state_dict:
+ pos_embed_checkpoint = state_dict['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int(
+ (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
+ # height (== width) for the new position embedding
+ new_size = int(self.num_patches**0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
+ embedding_size).permute(
+ 0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens,
+ size=(new_size, new_size),
+ mode='bicubic',
+ align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ state_dict['pos_embed'] = new_pos_embed
+ return state_dict
+
+ def forward(self, inputs):
+ B = inputs.shape[0]
+
+ x, hw_shape = self.patch_embed(inputs)
+
+ # stole cls_tokens impl from Phil Wang, thanks
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = x + self.pos_embed
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i == len(self.layers) - 1:
+ if self.final_norm:
+ x = self.norm1(x)
+ if i in self.out_indices:
+ out = x[:, 1:]
+ B, _, C = out.shape
+ out = out.reshape(B, hw_shape[0], hw_shape[1],
+ C).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ return tuple(outs)
diff --git a/model-index.yml b/model-index.yml
index d8e9516bf..2053fd049 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -24,6 +24,7 @@ Import:
- configs/icnet/icnet.yml
- configs/isanet/isanet.yml
- configs/knet/knet.yml
+- configs/mae/mae.yml
- configs/mobilenet_v2/mobilenet_v2.yml
- configs/mobilenet_v3/mobilenet_v3.yml
- configs/nonlocal_net/nonlocal_net.yml
diff --git a/tests/test_models/test_backbones/test_mae.py b/tests/test_models/test_backbones/test_mae.py
new file mode 100644
index 000000000..562d067a7
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mae.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmseg.models.backbones.mae import MAE
+from .utils import check_norm_state
+
+
+def test_mae_backbone():
+ with pytest.raises(TypeError):
+ # pretrained must be a string path
+ model = MAE()
+ model.init_weights(pretrained=0)
+
+ with pytest.raises(TypeError):
+ # img_size must be int or tuple
+ model = MAE(img_size=512.0)
+
+ with pytest.raises(TypeError):
+ # out_indices must be int ,list or tuple
+ model = MAE(out_indices=1.)
+
+ with pytest.raises(AssertionError):
+ # The length of img_size tuple must be lower than 3.
+ MAE(img_size=(224, 224, 224))
+
+ with pytest.raises(TypeError):
+ # Pretrained must be None or Str.
+ MAE(pretrained=123)
+
+ # Test img_size isinstance tuple
+ imgs = torch.randn(1, 3, 224, 224)
+ model = MAE(img_size=(224, ))
+ model.init_weights()
+ model(imgs)
+
+ # Test img_size isinstance tuple
+ imgs = torch.randn(1, 3, 224, 224)
+ model = MAE(img_size=(224, 224))
+ model(imgs)
+
+ # Test norm_eval = True
+ model = MAE(norm_eval=True)
+ model.train()
+
+ # Test BEiT backbone with input size of 224 and patch size of 16
+ model = MAE()
+ model.init_weights()
+ model.train()
+
+ # Test out_indices = list
+ model = MAE(out_indices=[2, 4, 8, 12])
+ model.train()
+
+ assert check_norm_state(model.modules(), True)
+
+ # Test image size = (224, 224)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+ # Test MAE backbone with input size of 256 and patch size of 16
+ model = MAE(img_size=(256, 256))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 256, 256)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 16, 16)
+
+ # Test MAE backbone with input size of 32 and patch size of 16
+ model = MAE(img_size=(32, 32))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 32, 32)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 2, 2)
+
+ # Test unbalanced size input image
+ model = MAE(img_size=(112, 224))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 112, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 7, 14)
+
+ # Test irregular input image
+ model = MAE(img_size=(234, 345))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 234, 345)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 21)
+
+ # Test init_values=0
+ model = MAE(init_values=0)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+ # Test final norm
+ model = MAE(final_norm=True)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+ # Test patch norm
+ model = MAE(patch_norm=True)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+
+def test_mae_init():
+ path = 'PATH_THAT_DO_NOT_EXIST'
+ # Test all combinations of pretrained and init_cfg
+ # pretrained=None, init_cfg=None
+ model = MAE(pretrained=None, init_cfg=None)
+ assert model.init_cfg is None
+ model.init_weights()
+
+ # pretrained=None
+ # init_cfg loads pretrain from an non-existent file
+ model = MAE(
+ pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+ assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+ # Test loading a checkpoint from an non-existent file
+ with pytest.raises(OSError):
+ model.init_weights()
+
+ # test resize_rel_pos_embed
+ value = torch.randn(732, 16)
+ abs_pos_embed_value = torch.rand(1, 17, 768)
+ ckpt = {
+ 'state_dict': {
+ 'layers.0.attn.relative_position_index': 0,
+ 'layers.0.attn.relative_position_bias_table': value,
+ 'pos_embed': abs_pos_embed_value
+ }
+ }
+ model = MAE(img_size=(512, 512))
+ with pytest.raises(AttributeError):
+ model.resize_rel_pos_embed(ckpt)
+
+ # test resize abs pos embed
+ ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
+
+ # pretrained=None
+ # init_cfg=123, whose type is unsupported
+ model = MAE(pretrained=None, init_cfg=123)
+ with pytest.raises(TypeError):
+ model.init_weights()
+
+ # pretrained loads pretrain from an non-existent file
+ # init_cfg=None
+ model = MAE(pretrained=path, init_cfg=None)
+ assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+ # Test loading a checkpoint from an non-existent file
+ with pytest.raises(OSError):
+ model.init_weights()
+
+ # pretrained loads pretrain from an non-existent file
+ # init_cfg loads pretrain from an non-existent file
+ with pytest.raises(AssertionError):
+ model = MAE(
+ pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+ with pytest.raises(AssertionError):
+ model = MAE(pretrained=path, init_cfg=123)
+
+ # pretrain=123, whose type is unsupported
+ # init_cfg=None
+ with pytest.raises(TypeError):
+ model = MAE(pretrained=123, init_cfg=None)
+
+ # pretrain=123, whose type is unsupported
+ # init_cfg loads pretrain from an non-existent file
+ with pytest.raises(AssertionError):
+ model = MAE(
+ pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+ # pretrain=123, whose type is unsupported
+ # init_cfg=123, whose type is unsupported
+ with pytest.raises(AssertionError):
+ model = MAE(pretrained=123, init_cfg=123)