[Feature]: Add MAE (#1307)

* [Fix]: Fix lint

* [WIP]: Add mae seg config

* [Feature]: Add MAE seg

* [Fix]: Fix mae dataset img scale bug

* [Fix]: Fix lint

* [Feature]: Change mae config to mae_segmentation's config

* [Feature]: Add interpolate pe when loading

* [Fix]: Fix pos_embed not used bug

* [Fix]: Fix lint

* [Fix]: Init rel pos embed with zeros

* [Fix]: Fix lint

* [Fix]: Change the type name of backbone to MAE

* [Fix]: Delete ade20k_512x512.py

* [Fix]: Use mmseg provided ade20k.py

* [Fix]: Change 1 sample per gpu to 2 samples per gpu

* [Fix]: Fix conflict

* [Refactor]: Use the TransformerEncoderLayer of BEiT

* [Feature]: Add UT

* [Fix]: Change the default value of qv bias to False

* [Fix]: Initialize relative pos table with zeros

* [Fix]: Delete redundant code in mae

* [Fix]: Fix lint

* [Fix]: Rename qkv_bias to qv_bias

* [Fix]: Add docstring to weight_init of MAEAttention

* [Refactor]: Delete qv_bias param

* [Fix]: Add reference to fix_init_weight

* [Fix]: Fix lint

* [Fix]: Delete extra crop size

* [Refactor]: Rename mae

* [Fix]: Set bias to True

* [Fix]: Delete redundant params

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Add resize abs pos embed

* [Fix]: Fix UT

* [Refactor]: Use build layer

* [Fix]: Add license and fix docstring

* [Fix]: Fix docstring

* [Feature]: Add README metafile

* [Fix]: Change 640 to 512

* [Fix]: Fix README

* fix readme of MAE

Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
Yuan Liu 2022-04-28 00:54:20 +08:00 committed by GitHub
parent 69b28e0b59
commit 6563cb513e
9 changed files with 672 additions and 1 deletion

configs/_base_/models/upernet_mae.py Normal file

@@ -0,0 +1,49 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='MAE',
img_size=(640, 640),
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=(3, 5, 7, 11),
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_cfg=dict(type='LN', eps=1e-6),
act_cfg=dict(type='GELU'),
norm_eval=False,
init_values=0.1),
neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
decode_head=dict(
type='UPerHead',
in_channels=[384, 384, 384, 384],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=384,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

configs/mae/README.md Normal file
@@ -0,0 +1,81 @@
# MAE
[Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
## Introduction
<!-- [BACKBONE] -->
<a href="https://github.com/facebookresearch/mae">Official Repo</a>
<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.24.0/mmseg/models/backbones/mae.py#L46">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
This paper shows that masked autoencoders (MAE) are scalable self-supervised learners for computer vision. Our MAE approach is simple: we mask random patches of the input image and reconstruct the missing pixels. It is based on two core designs. First, we develop an asymmetric encoder-decoder architecture, with an encoder that operates only on the visible subset of patches (without mask tokens), along with a lightweight decoder that reconstructs the original image from the latent representation and mask tokens. Second, we find that masking a high proportion of the input image, e.g., 75%, yields a nontrivial and meaningful self-supervisory task. Coupling these two designs enables us to train large models efficiently and effectively: we accelerate training (by 3x or more) and improve accuracy. Our scalable approach allows for learning high-capacity models that generalize well: e.g., a vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pre-training and shows promising scaling behavior.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/165456416-1cba54bf-b1b5-4bdf-ad86-d6390de7f342.png" width="70%"/>
</div>
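
As a rough illustration of the masking step described in the abstract (a standalone sketch, not part of this PR — the segmentation backbone added here consumes all patches and does no masking):

```python
import torch

def random_masking(patches, mask_ratio=0.75):
    """Keep a random subset of patch tokens, as in MAE pre-training."""
    B, N, C = patches.shape
    num_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N)  # one random score per patch
    ids_keep = torch.argsort(noise, dim=1)[:, :num_keep]
    visible = torch.gather(
        patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))
    mask = torch.ones(B, N)  # 1 marks a masked patch (to be reconstructed)
    mask.scatter_(1, ids_keep, 0.)
    return visible, mask

# 196 tokens from a 224x224 image with 16x16 patches; 49 stay visible.
visible, mask = random_masking(torch.randn(2, 196, 768))
```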
## Citation
```bibtex
@article{he2021masked,
title={Masked autoencoders are scalable vision learners},
author={He, Kaiming and Chen, Xinlei and Xie, Saining and Li, Yanghao and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv preprint arXiv:2111.06377},
year={2021}
}
```
## Usage
To use pre-trained models from other repositories, the checkpoint keys need to be converted first.
We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the keys of a MAE model from [the official repo](https://github.com/facebookresearch/mae) to MMSegmentation style.
```shell
python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```
E.g.
```shell
python tools/model_converters/beit2mmseg.py https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth pretrain/mae_pretrain_vit_base_mmcls.pth
```
This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.
In our default setting, the pretrained models and their corresponding original models are listed below:
| pretrained models | original models |
| ------ | -------- |
| mae_pretrain_vit_base_mmcls.pth | [`mae_pretrain_vit_base`](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth) |
Verify the single-scale results of the model:
```shell
sh tools/dist_test.sh \
configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py \
upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
```
Since the relative position embedding requires the input height and width to be equal, sliding-window inference is adopted for multi-scale testing, with `min_size=512` so that the shortest edge is at least 512. Multi-scale inference therefore uses a separate config rather than the `--aug-test` flag:
```shell
sh tools/dist_test.sh \
configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py \
upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth $GPUS --eval mIoU
```
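
The sliding window covers the (possibly rectangular) resized image with fixed 512x512 crops, matching `test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341))` in the config. A minimal sketch of the crop placement (illustrative only, not mmseg's exact `slide_inference`):

```python
def crop_origins(size, crop=512, stride=341):
    """Top/left coordinates of sliding crops along one image axis."""
    origins = list(range(0, max(size - crop, 0) + 1, stride))
    if origins[-1] + crop < size:
        origins.append(size - crop)  # keep the last crop inside the image
    return origins

# A 683x512 resized image needs two rows of crops and one column.
print(crop_origins(683))  # [0, 171]
print(crop_origins(512))  # [0]
```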
## Results and models
### ADE20K
| Method | Backbone | Crop Size | pretrain | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- | ------------: | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| UperNet | ViT-B | 512x512 | ImageNet-1K | 224x224 | 16 | 160000 | 9.96 | 7.14 | 48.13 | 48.70 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752.log.json) |

configs/mae/mae.yml Normal file
@@ -0,0 +1,23 @@
Models:
- Name: upernet_mae-base_fp16_8x2_512x512_160k_ade20k
In Collection: UperNet
Metadata:
backbone: ViT-B
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 140.06
hardware: V100
backend: PyTorch
batch size: 1
mode: FP16
resolution: (512,512)
Training Memory (GB): 9.96
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 48.13
mIoU(ms+flip): 48.7
Config: configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k/upernet_mae-base_fp16_8x2_512x512_160k_ade20k_20220426_174752-f92a2975.pth

configs/mae/upernet_mae-base_fp16_512x512_160k_ade20k_ms.py Normal file

@@ -0,0 +1,24 @@
_base_ = './upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True, min_size=512),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline),
samples_per_gpu=2)

configs/mae/upernet_mae-base_fp16_8x2_512x512_160k_ade20k.py Normal file

@@ -0,0 +1,48 @@
_base_ = [
'../_base_/models/upernet_mae.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
model = dict(
pretrained='./pretrain/mae_pretrain_vit_base_mmcls.pth',
backbone=dict(
type='MAE',
img_size=(512, 512),
patch_size=16,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
init_values=1.0,
drop_path_rate=0.1,
out_indices=[3, 5, 7, 11]),
neck=dict(embed_dim=768, rescales=[4, 2, 1, 0.5]),
decode_head=dict(
in_channels=[768, 768, 768, 768], num_classes=150, channels=768),
auxiliary_head=dict(in_channels=768, num_classes=150),
test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)))
optimizer = dict(
_delete_=True,
type='AdamW',
lr=1e-4,
betas=(0.9, 0.999),
weight_decay=0.05,
constructor='LayerDecayOptimizerConstructor',
paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.65))
lr_config = dict(
_delete_=True,
policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0,
min_lr=0.0,
by_epoch=False)
# mixed precision
fp16 = dict(loss_scale='dynamic')
# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
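
The `LayerDecayOptimizerConstructor` above gives deeper transformer blocks larger learning rates than shallower ones. A sketch of the BEiT-style decay rule with this config's values (the exact parameter grouping lives in the constructor; the layer-id convention here is an assumption):

```python
# layer_id 0 = patch embedding, 1..12 = transformer blocks,
# num_layers + 1 = decode head (full base lr).
base_lr, rate, num_layers = 1e-4, 0.65, 12

for layer_id in range(num_layers + 2):
    scale = rate ** (num_layers + 1 - layer_id)
    print(f'layer {layer_id:2d}: lr scale {scale:.4f} -> lr {base_lr * scale:.2e}')
```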

mmseg/models/backbones/__init__.py

@@ -7,6 +7,7 @@ from .erfnet import ERFNet
 from .fast_scnn import FastSCNN
 from .hrnet import HRNet
 from .icnet import ICNet
+from .mae import MAE
 from .mit import MixVisionTransformer
 from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetV3
@@ -25,5 +26,5 @@ __all__ = [
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
     'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
-    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT'
+    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
 ]

mmseg/models/backbones/mae.py Normal file

@@ -0,0 +1,261 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
trunc_normal_)
from mmcv.runner import ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
class MAEAttention(BEiTAttention):
"""Multi-head self-attention with relative position bias used in MAE.
    This module differs from ``BEiTAttention`` in that it initializes the
    relative position bias table with zeros.
"""
def init_weights(self):
"""Initialize relative position bias with zeros."""
        # MAE initializes the relative position bias with zeros, while the
        # parent class BEiT initializes it with `trunc_normal`. Since the
        # bias table is already created as zeros, `init_weights` here
        # intentionally does nothing.
pass
class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
"""Implements one encoder layer in Vision Transformer.
    This module differs from ``BEiTTransformerEncoderLayer`` by replacing
``BEiTAttention`` with ``MAEAttention``.
"""
def build_attn(self, attn_cfg):
self.attn = MAEAttention(**attn_cfg)
@BACKBONES.register_module()
class MAE(BEiT):
"""VisionTransformer with support for patch.
Args:
img_size (int | tuple): Input image size. Default: 224.
patch_size (int): The patch size. Default: 16.
in_channels (int): Number of input channels. Default: 3.
embed_dims (int): embedding dimension. Default: 768.
num_layers (int): depth of transformer. Default: 12.
num_heads (int): number of attention heads. Default: 12.
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
out_indices (list | tuple | int): Output from which stages.
Default: -1.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
patch_norm (bool): Whether to add a norm in PatchEmbed Block.
Default: False.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Default: False.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
pretrained (str, optional): model pretrained path. Default: None.
init_values (float): Initialize the values of Attention and FFN
with learnable scaling. Defaults to 0.1.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims=768,
num_layers=12,
num_heads=12,
mlp_ratio=4,
out_indices=-1,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
num_fcs=2,
norm_eval=False,
pretrained=None,
init_values=0.1,
init_cfg=None):
super(MAE, self).__init__(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dims=embed_dims,
num_layers=num_layers,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
out_indices=out_indices,
qv_bias=False,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
patch_norm=patch_norm,
final_norm=final_norm,
num_fcs=num_fcs,
norm_eval=norm_eval,
pretrained=pretrained,
init_values=init_values,
init_cfg=init_cfg)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.num_patches = self.patch_shape[0] * self.patch_shape[1]
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches + 1, embed_dims))
def _build_layers(self):
dpr = [
x.item()
for x in torch.linspace(0, self.drop_path_rate, self.num_layers)
]
self.layers = ModuleList()
for i in range(self.num_layers):
self.layers.append(
MAETransformerEncoderLayer(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=self.mlp_ratio * self.embed_dims,
attn_drop_rate=self.attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=self.num_fcs,
bias=True,
act_cfg=self.act_cfg,
norm_cfg=self.norm_cfg,
window_size=self.patch_shape,
init_values=self.init_values))
def fix_init_weight(self):
"""Rescale the initialization according to layer id.
This function is copied from https://github.com/microsoft/unilm/blob/master/beit/modeling_pretrain.py. # noqa: E501
Copyright (c) Microsoft Corporation
Licensed under the MIT License
"""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.layers):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
def init_weights(self):
def _init_weights(m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
self.apply(_init_weights)
self.fix_init_weight()
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
logger = get_root_logger()
checkpoint = _load_checkpoint(
self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
state_dict = self.resize_rel_pos_embed(checkpoint)
state_dict = self.resize_abs_pos_embed(state_dict)
self.load_state_dict(state_dict, False)
elif self.init_cfg is not None:
super(MAE, self).init_weights()
else:
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
# Copyright 2019 Ross Wightman
# Licensed under the Apache License, Version 2.0 (the "License")
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
if 'ffn' in n:
nn.init.normal_(m.bias, mean=0., std=1e-6)
else:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv2d):
kaiming_init(m, mode='fan_in', bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
def resize_abs_pos_embed(self, state_dict):
if 'pos_embed' in state_dict:
pos_embed_checkpoint = state_dict['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_extra_tokens = self.pos_embed.shape[-2] - self.num_patches
# height (== width) for the checkpoint position embedding
orig_size = int(
(pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
# height (== width) for the new position embedding
new_size = int(self.num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
embedding_size).permute(
0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode='bicubic',
align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
state_dict['pos_embed'] = new_pos_embed
return state_dict
def forward(self, inputs):
B = inputs.shape[0]
x, hw_shape = self.patch_embed(inputs)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1:
if self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
out = x[:, 1:]
B, _, C = out.shape
out = out.reshape(B, hw_shape[0], hw_shape[1],
C).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
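
With the backbone registered, a quick smoke test (a sketch, assuming a working mmseg installation; shapes follow the defaults above) could look like:

```python
import torch
from mmseg.models import build_backbone

mae = build_backbone(
    dict(
        type='MAE',
        img_size=(224, 224),
        patch_size=16,
        out_indices=[3, 5, 7, 11]))
feats = mae(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])  # four maps of (1, 768, 14, 14)
```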

model-index.yml

@@ -24,6 +24,7 @@ Import:
 - configs/icnet/icnet.yml
 - configs/isanet/isanet.yml
 - configs/knet/knet.yml
+- configs/mae/mae.yml
 - configs/mobilenet_v2/mobilenet_v2.yml
 - configs/mobilenet_v3/mobilenet_v3.yml
 - configs/nonlocal_net/nonlocal_net.yml

tests/test_models/test_backbones/test_mae.py Normal file

@@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones.mae import MAE
from .utils import check_norm_state
def test_mae_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = MAE()
model.init_weights(pretrained=0)
with pytest.raises(TypeError):
# img_size must be int or tuple
model = MAE(img_size=512.0)
with pytest.raises(TypeError):
        # out_indices must be int, list or tuple
model = MAE(out_indices=1.)
with pytest.raises(AssertionError):
        # The length of the img_size tuple must be less than 3.
MAE(img_size=(224, 224, 224))
with pytest.raises(TypeError):
        # pretrained must be None or str.
MAE(pretrained=123)
    # Test img_size as a 1-tuple
imgs = torch.randn(1, 3, 224, 224)
model = MAE(img_size=(224, ))
model.init_weights()
model(imgs)
    # Test img_size as a 2-tuple
imgs = torch.randn(1, 3, 224, 224)
model = MAE(img_size=(224, 224))
model(imgs)
# Test norm_eval = True
model = MAE(norm_eval=True)
model.train()
    # Test MAE backbone with input size of 224 and patch size of 16
model = MAE()
model.init_weights()
model.train()
# Test out_indices = list
model = MAE(out_indices=[2, 4, 8, 12])
model.train()
assert check_norm_state(model.modules(), True)
# Test image size = (224, 224)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test MAE backbone with input size of 256 and patch size of 16
model = MAE(img_size=(256, 256))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 16, 16)
# Test MAE backbone with input size of 32 and patch size of 16
model = MAE(img_size=(32, 32))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 2, 2)
# Test unbalanced size input image
model = MAE(img_size=(112, 224))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 112, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 7, 14)
# Test irregular input image
model = MAE(img_size=(234, 345))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 234, 345)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 21)
# Test init_values=0
model = MAE(init_values=0)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test final norm
model = MAE(final_norm=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
# Test patch norm
model = MAE(patch_norm=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)
def test_mae_init():
path = 'PATH_THAT_DO_NOT_EXIST'
# Test all combinations of pretrained and init_cfg
# pretrained=None, init_cfg=None
model = MAE(pretrained=None, init_cfg=None)
assert model.init_cfg is None
model.init_weights()
# pretrained=None
    # init_cfg loads pretrain from a non-existent file
model = MAE(
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
with pytest.raises(OSError):
model.init_weights()
# test resize_rel_pos_embed
value = torch.randn(732, 16)
abs_pos_embed_value = torch.rand(1, 17, 768)
ckpt = {
'state_dict': {
'layers.0.attn.relative_position_index': 0,
'layers.0.attn.relative_position_bias_table': value,
'pos_embed': abs_pos_embed_value
}
}
model = MAE(img_size=(512, 512))
with pytest.raises(AttributeError):
model.resize_rel_pos_embed(ckpt)
# test resize abs pos embed
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
# pretrained=None
# init_cfg=123, whose type is unsupported
model = MAE(pretrained=None, init_cfg=123)
with pytest.raises(TypeError):
model.init_weights()
    # pretrained loads pretrain from a non-existent file
# init_cfg=None
model = MAE(pretrained=path, init_cfg=None)
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
with pytest.raises(OSError):
model.init_weights()
    # pretrained loads pretrain from a non-existent file
    # init_cfg loads pretrain from a non-existent file
with pytest.raises(AssertionError):
model = MAE(
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
with pytest.raises(AssertionError):
model = MAE(pretrained=path, init_cfg=123)
# pretrain=123, whose type is unsupported
# init_cfg=None
with pytest.raises(TypeError):
model = MAE(pretrained=123, init_cfg=None)
# pretrain=123, whose type is unsupported
    # init_cfg loads pretrain from a non-existent file
with pytest.raises(AssertionError):
model = MAE(
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
# pretrain=123, whose type is unsupported
# init_cfg=123, whose type is unsupported
with pytest.raises(AssertionError):
model = MAE(pretrained=123, init_cfg=123)