diff --git a/configs/_base_/models/upernet_mae.py b/configs/_base_/models/upernet_mae.py
new file mode 100644
index 000000000..1e0da7082
--- /dev/null
+++ b/configs/_base_/models/upernet_mae.py
@@ -0,0 +1,49 @@
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='MAE',
+        img_size=(640, 640),
+        patch_size=16,
+        in_channels=3,
+        embed_dims=768,
+        num_layers=12,
+        num_heads=12,
+        mlp_ratio=4,
+        out_indices=(3, 5, 7, 11),
+        attn_drop_rate=0.0,
+        drop_path_rate=0.1,
+        norm_cfg=dict(type='LN', eps=1e-6),
+        act_cfg=dict(type='GELU'),
+        norm_eval=False,
+        init_values=0.1),
+    neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
+    decode_head=dict(
+        type='UPerHead',
+        in_channels=[384, 384, 384, 384],
+        in_index=[0, 1, 2, 3],
+        pool_scales=(1, 2, 3, 6),
+        channels=512,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=384,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/configs/mae/README.md b/configs/mae/README.md
new file mode 100644
index 000000000..f42ff0a71
--- /dev/null
+++ b/configs/mae/README.md
@@ -0,0 +1,81 @@
+# MAE
+
+[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
+
+## Introduction
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. Our MAE approach is simple: we mask random patches of the input image and reconstruct the missing pixels. It is based on two core designs. First, we develop an asymmetric encoder-decoder architecture, with an encoder that operates only on the visible subset of patches (without mask tokens), along with a lightweight decoder that reconstructs the original image from the latent representation and mask tokens. Second, we find that masking a high proportion of the input image, e.g., 75%, yields a nontrivial and meaningful self-supervisory task. Coupling these two designs enables us to train large models efficiently and effectively: we accelerate training (by 3x or more) and improve accuracy. Our scalable approach allows for learning high-capacity models that generalize well: e.g., a vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pre-training and shows promising scaling behavior.
+
+## Citation
+
+```bibtex
+@article{he2021masked,
+  title={Masked autoencoders are scalable vision learners},
+  author={He, Kaiming and Chen, Xinlei and Xie, Saining and Li, Yanghao and Doll{\'a}r, Piotr and Girshick, Ross},
+  journal={arXiv preprint arXiv:2111.06377},
+  year={2021}
+}
+```
+
+## Usage
+
+To use pre-trained models from other repositories, it is necessary to convert the keys first.
+
+We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the keys of an MAE model from [the official repo](https://github.com/facebookresearch/mae) to MMSegmentation style.
+
+```shell
+python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+
+E.g.
+
+```shell
+python tools/model_converters/beit2mmseg.py https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth pretrain/mae_pretrain_vit_base_mmcls.pth
+```
+
+This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
+
+In our default setting, the pretrained models are defined as below:
+
+| pretrained models               | original models                                                                                |
+| ------------------------------- | ---------------------------------------------------------------------------------------------- |
+| mae_pretrain_vit_base_mmcls.pth | [mae_pretrain_vit_base](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth) |
+
+Verify the single-scale results of the model:
+
+```shell
+sh tools/dist_test.sh \
+configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py \
+upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
+```
+
+Since the relative position embedding requires the input height and width to be equal, sliding-window inference is adopted for multi-scale testing. We set `min_size=512` so that the shortest edge of the input is at least 512, and multi-scale inference therefore uses a separate config instead of the `--aug-test` option.
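To make the effect of `min_size` concrete, the short edge targeted at each test ratio works out roughly as below. This is only an illustration of the behaviour described above (plain Python, not the actual `Resize` transform):

```python
# Rough illustration: whatever ratio is used, the shortest edge fed to the
# sliding-window test never drops below 512, so (512, 512) crops stay valid.
short_side = 512                      # short side of img_scale=(2048, 512)
ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
min_size = 512

for r in ratios:
    target_short = max(int(short_side * r), min_size)
    print(f'ratio {r}: target short edge = {target_short}')
```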
+For multi-scale inference:
+
+```shell
+sh tools/dist_test.sh \
+configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py \
+upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
+```
+
+## Results and models
+
+### ADE20K
+
+| Method  | Backbone | Crop Size | pretrain    | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU  | mIoU(ms+flip) | config                                                                                                                           | download                                                                                                                                                                                                                                                                                                                                                         |
+| ------- | -------- | --------- | ----------- | ----------------- | ---------- | ------- | -------- | -------------- | ----- | ------------- | -------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| UperNet | ViT-B    | 512x512   | ImageNet-1K | 224x224           | 16         | 160000  | 9.96     | 7.14           | 48.13 | 48.70         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752.log.json) |
diff --git a/configs/mae/mae.yml b/configs/mae/mae.yml
new file mode 100644
index 000000000..5a869344e
--- /dev/null
+++ b/configs/mae/mae.yml
@@ -0,0 +1,23 @@
+Models:
+- Name: upernet_mae-base_fp16_8x2_512x512_160k_ade20k
+  In Collection: UperNet
+  Metadata:
+    backbone: ViT-B
+    crop size: (512,512)
+    lr schd: 160000
+    inference time (ms/im):
+    - value: 140.06
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP16
+      resolution: (512,512)
+    Training Memory (GB): 9.96
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 48.13
+      mIoU(ms+flip): 48.7
+  Config: configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth
diff --git a/configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py b/configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py
new file mode 100644
index 000000000..85b3be303
--- /dev/null
+++ b/configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py
@@ -0,0 +1,24 @@
+_base_ = './upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py'
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 512),
+        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=True,
+        transforms=[
+            dict(type='Resize', keep_ratio=True, min_size=512),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline),
+    samples_per_gpu=2)
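The training config that follows expects the converted checkpoint at `./pretrain/mae_pretrain_vit_base_mmcls.pth` (see the conversion step in the README above). Before training, it can be worth a quick sanity check that the converted file really contains MMSegmentation-style keys; a minimal, purely illustrative sketch (the expected names such as `layers.<i>.attn.*` and `pos_embed` are inferred from the MAE backbone and tests added in this PR):

```python
# Illustrative sanity check of a converted MAE checkpoint; not part of the patch.
import torch

ckpt = torch.load('pretrain/mae_pretrain_vit_base_mmcls.pth', map_location='cpu')
# Converted checkpoints may be stored bare or wrapped in a 'state_dict' entry.
state_dict = ckpt.get('state_dict', ckpt)

# After conversion we expect MMSegmentation-style names such as 'cls_token',
# 'pos_embed', 'patch_embed.*' and 'layers.<i>.attn.*' / 'layers.<i>.ffn.*'.
for key, value in sorted(state_dict.items()):
    print(key, tuple(value.shape) if hasattr(value, 'shape') else type(value))
```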
diff --git a/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py b/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
new file mode 100644
index 000000000..cb236cc04
--- /dev/null
+++ b/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
@@ -0,0 +1,48 @@
+_base_ = [
+    '../_base_/models/upernet_mae.py', '../_base_/datasets/ade20k.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+
+model = dict(
+    pretrained='./pretrain/mae_pretrain_vit_base_mmcls.pth',
+    backbone=dict(
+        type='MAE',
+        img_size=(512, 512),
+        patch_size=16,
+        embed_dims=768,
+        num_layers=12,
+        num_heads=12,
+        mlp_ratio=4,
+        init_values=1.0,
+        drop_path_rate=0.1,
+        out_indices=[3, 5, 7, 11]),
+    neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5]),
+    decode_head=dict(
+        in_channels=[768, 768, 768, 768], num_classes=150, channels=768),
+    auxiliary_head=dict(in_channels=768, num_classes=150),
+    test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)))
+
+optimizer = dict(
+    _delete_=True,
+    type='AdamW',
+    lr=1e-4,
+    betas=(0.9, 0.999),
+    weight_decay=0.05,
+    constructor='LayerDecayOptimizerConstructor',
+    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
+
+lr_config = dict(
+    _delete_=True,
+    policy='poly',
+    warmup='linear',
+    warmup_iters=1500,
+    warmup_ratio=1e-6,
+    power=1.0,
+    min_lr=0.0,
+    by_epoch=False)
+
+# mixed precision
+fp16 = dict(loss_scale='dynamic')
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
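The optimizer above uses `LayerDecayOptimizerConstructor` with `layer_decay_rate=0.65`, which gives earlier transformer layers exponentially smaller learning rates than later ones. A minimal sketch of the usual BEiT-style decay rule (the exact parameter grouping lives inside the constructor in MMSegmentation; the grouping shown here is an assumption for illustration only):

```python
# Minimal sketch of layer-wise learning-rate decay as configured above
# (lr=1e-4, num_layers=12, layer_decay_rate=0.65). Layer 0 stands for the
# patch/position embeddings, layers 1..12 for the transformer blocks, and
# num_layers + 1 for the heads; this grouping is assumed, not copied from
# the MMSegmentation implementation.
base_lr = 1e-4
num_layers = 12
decay = 0.65

for layer_id in range(num_layers + 2):
    scale = decay ** (num_layers + 1 - layer_id)
    print(f'layer {layer_id:2d}: lr scale {scale:.4f} -> lr {base_lr * scale:.2e}')
```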
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 1ede4874d..bda42bb69 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -7,6 +7,7 @@ from .erfnet import ERFNet
 from .fast_scnn import FastSCNN
 from .hrnet import HRNet
 from .icnet import ICNet
+from .mae import MAE
 from .mit import MixVisionTransformer
 from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetV3
@@ -25,5 +26,5 @@ __all__ = [
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
     'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
-    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT'
+    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
 ]
diff --git a/mmseg/models/backbones/mae.py b/mmseg/models/backbones/mae.py
new file mode 100644
index 000000000..d3e8754bd
--- /dev/null
+++ b/mmseg/models/backbones/mae.py
@@ -0,0 +1,261 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+                                        trunc_normal_)
+from mmcv.runner import ModuleList, _load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
+
+
+class MAEAttention(BEiTAttention):
+    """Multi-head self-attention with relative position bias used in MAE.
+
+    This module is different from ``BEiTAttention`` by initializing the
+    relative bias table with zeros.
+    """
+
+    def init_weights(self):
+        """Initialize relative position bias with zeros."""
+
+        # As MAE initializes the relative position bias with zeros and this
+        # class is inherited from BEiT, which initializes the bias with
+        # `trunc_normal`, `init_weights` here does
+        # nothing and just passes directly
+
+        pass
+
+
+class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
+    """Implements one encoder layer in Vision Transformer.
+
+    This module is different from ``BEiTTransformerEncoderLayer`` by replacing
+    ``BEiTAttention`` with ``MAEAttention``.
+    """
+
+    def build_attn(self, attn_cfg):
+        self.attn = MAEAttention(**attn_cfg)
+
+
+@BACKBONES.register_module()
+class MAE(BEiT):
+    """VisionTransformer with support for patch.
+
+    Args:
+        img_size (int | tuple): Input image size. Default: 224.
+        patch_size (int): The patch size. Default: 16.
+        in_channels (int): Number of input channels. Default: 3.
+        embed_dims (int): embedding dimension. Default: 768.
+        num_layers (int): depth of transformer. Default: 12.
+        num_heads (int): number of attention heads. Default: 12.
+        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+            Default: 4.
+        out_indices (list | tuple | int): Output from which stages.
+            Default: -1.
+        attn_drop_rate (float): The drop out rate for attention layer.
+            Default 0.0
+        drop_path_rate (float): stochastic depth rate. Default 0.0.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN')
+        act_cfg (dict): The activation config for FFNs.
+            Default: dict(type='GELU').
+        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
+            Default: False.
+        final_norm (bool): Whether to add an additional layer to normalize
+            final feature map. Default: False.
+        num_fcs (int): The number of fully-connected layers for FFNs.
+            Default: 2.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        pretrained (str, optional): model pretrained path. Default: None.
+        init_values (float): Initialize the values of Attention and FFN
+            with learnable scaling. Defaults to 0.1.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_channels=3,
+                 embed_dims=768,
+                 num_layers=12,
+                 num_heads=12,
+                 mlp_ratio=4,
+                 out_indices=-1,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN'),
+                 act_cfg=dict(type='GELU'),
+                 patch_norm=False,
+                 final_norm=False,
+                 num_fcs=2,
+                 norm_eval=False,
+                 pretrained=None,
+                 init_values=0.1,
+                 init_cfg=None):
+        super(MAE, self).__init__(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dims=embed_dims,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            mlp_ratio=mlp_ratio,
+            out_indices=out_indices,
+            qv_bias=False,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=drop_path_rate,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            patch_norm=patch_norm,
+            final_norm=final_norm,
+            num_fcs=num_fcs,
+            norm_eval=norm_eval,
+            pretrained=pretrained,
+            init_values=init_values,
+            init_cfg=init_cfg)
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
+
+        self.num_patches = self.patch_shape[0] * self.patch_shape[1]
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, self.num_patches + 1, embed_dims))
+
+    def _build_layers(self):
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
+        ]
+        self.layers = ModuleList()
+        for i in range(self.num_layers):
+            self.layers.append(
+                MAETransformerEncoderLayer(
+                    embed_dims=self.embed_dims,
+                    num_heads=self.num_heads,
+                    feedforward_channels=self.mlp_ratio * self.embed_dims,
+                    attn_drop_rate=self.attn_drop_rate,
+                    drop_path_rate=dpr[i],
+                    num_fcs=self.num_fcs,
+                    bias=True,
+                    act_cfg=self.act_cfg,
+                    norm_cfg=self.norm_cfg,
+                    window_size=self.patch_shape,
+                    init_values=self.init_values))
+
+    def fix_init_weight(self):
+        """Rescale the initialization according to layer id.
+
+        This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
+        Copyright (c) Microsoft Corporation
+        Licensed under the MIT License
+        """
+
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.layers):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
+
+    def init_weights(self):
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        self.apply(_init_weights)
+        self.fix_init_weight()
+
+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg.get('type') == 'Pretrained'):
+            logger = get_root_logger()
+            checkpoint = _load_checkpoint(
+                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
+            state_dict = self.resize_rel_pos_embed(checkpoint)
+            state_dict = self.resize_abs_pos_embed(state_dict)
+            self.load_state_dict(state_dict, False)
+        elif self.init_cfg is not None:
+            super(MAE, self).init_weights()
+        else:
+            # We only implement the 'jax_impl' initialization implemented at
+            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
+            # Copyright 2019 Ross Wightman
+            # Licensed under the Apache License, Version 2.0 (the "License")
+            trunc_normal_(self.cls_token, std=.02)
+            for n, m in self.named_modules():
+                if isinstance(m, nn.Linear):
+                    trunc_normal_(m.weight, std=.02)
+                    if m.bias is not None:
+                        if 'ffn' in n:
+                            nn.init.normal_(m.bias, mean=0., std=1e-6)
+                        else:
+                            nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.Conv2d):
+                    kaiming_init(m, mode='fan_in', bias=0.)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+                    constant_init(m, val=1.0, bias=0.)
+
+    def resize_abs_pos_embed(self, state_dict):
+        if 'pos_embed' in state_dict:
+            pos_embed_checkpoint = state_dict['pos_embed']
+            embedding_size = pos_embed_checkpoint.shape[-1]
+            num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
+            # height (== width) for the checkpoint position embedding
+            orig_size = int(
+                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
+            # height (== width) for the new position embedding
+            new_size = int(self.num_patches**0.5)
+            # class_token and dist_token are kept unchanged
+            if orig_size != new_size:
+                extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+                # only the position tokens are interpolated
+                pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+                pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
+                                                embedding_size).permute(
+                                                    0, 3, 1, 2)
+                pos_tokens = torch.nn.functional.interpolate(
+                    pos_tokens,
+                    size=(new_size, new_size),
+                    mode='bicubic',
+                    align_corners=False)
+                pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+                new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+                state_dict['pos_embed'] = new_pos_embed
+        return state_dict
+
+    def forward(self, inputs):
+        B = inputs.shape[0]
+
+        x, hw_shape = self.patch_embed(inputs)
+
+        # stole cls_tokens impl from Phil Wang, thanks
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = x + self.pos_embed
+
+        outs = []
+        for i, layer in enumerate(self.layers):
+            x = layer(x)
+            if i == len(self.layers) - 1:
+                if self.final_norm:
+                    x = self.norm1(x)
+            if i in self.out_indices:
+                out = x[:, 1:]
+                B, _, C = out.shape
+                out = out.reshape(B, hw_shape[0], hw_shape[1],
+                                  C).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+        return tuple(outs)
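For orientation, the backbone above returns one feature map per entry in `out_indices`, each with `embed_dims=768` channels at 1/16 of the input resolution; the `Feature2Pyramid` neck in the configs then rescales them by `[4, 2, 1, 0.5]` for UPerHead. A minimal forward-pass sketch, mirroring the unit tests added below (requires MMSegmentation with this patch applied):

```python
# Minimal usage sketch of the MAE backbone added above; illustration only.
import torch

from mmseg.models.backbones.mae import MAE

model = MAE(img_size=(512, 512), out_indices=[3, 5, 7, 11])
model.init_weights()
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 512, 512))

# Four feature maps, all at 1/16 resolution with 768 channels; the
# Feature2Pyramid neck in the config turns them into a 4-level pyramid.
for f in feats:
    print(tuple(f.shape))  # expected: (1, 768, 32, 32) for each output index
```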
diff --git a/model-index.yml b/model-index.yml
index d8e9516bf..2053fd049 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -24,6 +24,7 @@ Import:
 - configs/icnet/icnet.yml
 - configs/isanet/isanet.yml
 - configs/knet/knet.yml
+- configs/mae/mae.yml
 - configs/mobilenet_v2/mobilenet_v2.yml
 - configs/mobilenet_v3/mobilenet_v3.yml
 - configs/nonlocal_net/nonlocal_net.yml
diff --git a/tests/test_models/test_backbones/test_mae.py b/tests/test_models/test_backbones/test_mae.py
new file mode 100644
index 000000000..562d067a7
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mae.py
@@ -0,0 +1,183 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmseg.models.backbones.mae import MAE
+from .utils import check_norm_state
+
+
+def test_mae_backbone():
+    with pytest.raises(TypeError):
+        # pretrained must be a string path
+        model = MAE()
+        model.init_weights(pretrained=0)
+
+    with pytest.raises(TypeError):
+        # img_size must be int or tuple
+        model = MAE(img_size=512.0)
+
+    with pytest.raises(TypeError):
+        # out_indices must be int, list or tuple
+        model = MAE(out_indices=1.)
+
+    with pytest.raises(AssertionError):
+        # The length of the img_size tuple must be lower than 3.
+        MAE(img_size=(224, 224, 224))
+
+    with pytest.raises(TypeError):
+        # pretrained must be None or str.
+        MAE(pretrained=123)
+
+    # Test img_size as a single-element tuple
+    imgs = torch.randn(1, 3, 224, 224)
+    model = MAE(img_size=(224, ))
+    model.init_weights()
+    model(imgs)
+
+    # Test img_size as a tuple
+    imgs = torch.randn(1, 3, 224, 224)
+    model = MAE(img_size=(224, 224))
+    model(imgs)
+
+    # Test norm_eval = True
+    model = MAE(norm_eval=True)
+    model.train()
+
+    # Test MAE backbone with input size of 224 and patch size of 16
+    model = MAE()
+    model.init_weights()
+    model.train()
+
+    # Test out_indices = list
+    model = MAE(out_indices=[2, 4, 8, 12])
+    model.train()
+
+    assert check_norm_state(model.modules(), True)
+
+    # Test image size = (224, 224)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test MAE backbone with input size of 256 and patch size of 16
+    model = MAE(img_size=(256, 256))
+    model.init_weights()
+    model.train()
+    imgs = torch.randn(1, 3, 256, 256)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 16, 16)
+
+    # Test MAE backbone with input size of 32 and patch size of 16
+    model = MAE(img_size=(32, 32))
+    model.init_weights()
+    model.train()
+    imgs = torch.randn(1, 3, 32, 32)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 2, 2)
+
+    # Test unbalanced size input image
+    model = MAE(img_size=(112, 224))
+    model.init_weights()
+    model.train()
+    imgs = torch.randn(1, 3, 112, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 7, 14)
+
+    # Test irregular input image
+    model = MAE(img_size=(234, 345))
+    model.init_weights()
+    model.train()
+    imgs = torch.randn(1, 3, 234, 345)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 21)
+
+    # Test init_values=0
+    model = MAE(init_values=0)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test final norm
+    model = MAE(final_norm=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
+
+    # Test patch norm
+    model = MAE(patch_norm=True)
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat[-1].shape == (1, 768, 14, 14)
+
+
+def test_mae_init():
+    path = 'PATH_THAT_DO_NOT_EXIST'
+    # Test all combinations of pretrained and init_cfg
+    # pretrained=None, init_cfg=None
+    model = MAE(pretrained=None, init_cfg=None)
+    assert model.init_cfg is None
+    model.init_weights()
+
+    # pretrained=None
+    # init_cfg loads pretrain from a non-existent file
+    model = MAE(
+        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # test resize_rel_pos_embed
+    value = torch.randn(732, 16)
+    abs_pos_embed_value = torch.rand(1, 17, 768)
+    ckpt = {
+        'state_dict': {
+            'layers.0.attn.relative_position_index': 0,
+            'layers.0.attn.relative_position_bias_table': value,
+            'pos_embed': abs_pos_embed_value
+        }
+    }
+    model = MAE(img_size=(512, 512))
+    with pytest.raises(AttributeError):
+        model.resize_rel_pos_embed(ckpt)
+
+    # test resize abs pos embed
+    ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
+
+    # pretrained=None
+    # init_cfg=123, whose type is unsupported
+    model = MAE(pretrained=None, init_cfg=123)
+    with pytest.raises(TypeError):
+        model.init_weights()
+
+    # pretrained loads pretrain from a non-existent file
+    # init_cfg=None
+    model = MAE(pretrained=path, init_cfg=None)
+    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+    # Test loading a checkpoint from a non-existent file
+    with pytest.raises(OSError):
+        model.init_weights()
+
+    # pretrained loads pretrain from a non-existent file
+    # init_cfg loads pretrain from a non-existent file
+    with pytest.raises(AssertionError):
+        model = MAE(
+            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+    with pytest.raises(AssertionError):
+        model = MAE(pretrained=path, init_cfg=123)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=None
+    with pytest.raises(TypeError):
+        model = MAE(pretrained=123, init_cfg=None)
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg loads pretrain from a non-existent file
+    with pytest.raises(AssertionError):
+        model = MAE(
+            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+    # pretrained=123, whose type is unsupported
+    # init_cfg=123, whose type is unsupported
+    with pytest.raises(AssertionError):
+        model = MAE(pretrained=123, init_cfg=123)
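Finally, the `resize_abs_pos_embed` step exercised above is what lets a 224x224 pre-trained checkpoint (a 14x14 grid of position tokens plus the class token) be reused at the 512x512 fine-tuning resolution (a 32x32 grid). A standalone sketch of that bicubic interpolation, kept separate from the patch itself:

```python
# Standalone illustration of the absolute position-embedding resize performed
# by MAE.resize_abs_pos_embed above: a 224x224 pretraining grid (14x14 patches
# plus one cls token) is interpolated to a 512x512 fine-tuning grid (32x32).
import torch
import torch.nn.functional as F

embed_dims = 768
orig_size, new_size = 14, 32                    # 224 // 16 and 512 // 16
pos_embed = torch.randn(1, orig_size * orig_size + 1, embed_dims)

cls_token = pos_embed[:, :1]                    # kept unchanged
pos_tokens = pos_embed[:, 1:]                   # only these are interpolated
pos_tokens = pos_tokens.reshape(1, orig_size, orig_size,
                                embed_dims).permute(0, 3, 1, 2)
pos_tokens = F.interpolate(
    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)

new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
print(new_pos_embed.shape)  # torch.Size([1, 1025, 768])
```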