mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* [Fix]: Fix lint * [WIP]: Add mae seg config * [Feature]: Add MAE seg * [Fix]: Fix mae dataset img scale bug * [Fix]: Fix lint * [Feature]: Change mae config to mae_segmentation's config * [Feature]: Add interpolate pe when loading * [Fix]: Fix pos_embed not used bug * [Fix]: Fix lint * [Fix]: Init rel pos embed with zeros * [Fix]: Fix lint * [Fix]: Change the type name of backbone to MAE * [Fix]: Delete ade20k_512x512.py * [Fix]: Use mmseg provided ade20k.py * [Fix]: Change 1 sample per gpu to 2 samples per gpu * [Fix]: Fix conflict * [Refactor]: Use the TransformerEncoderLayer of BEiT * [Feature]: Add UT * [Fix]: Change the default value of qv bias to False * [Fix]: Initialize relative pos table with zeros * [Fix]: Delete redundant code in mae * [Fix]: Fix lint * [Fix]: Rename qkv_bias to qv_bias * [Fix]: Add docstring to weight_init of MAEAttention * [Refactor]: Delete qv_bias param * [Fix]: Add reference to fix_init_weight * [Fix]: Fix lint * [Fix]: Delete extra crop size * [Refactor]: Rename mae * [Fix]: Set bias to True * [Fix]: Delete redundant params * [Fix]: Fix lint * [Fix]: Fix UT * [Fix]: Add resize abs pos embed * [Fix]: Fix UT * [Refactor]: Use build layer * [Fix]: Add licsense and fix docstring * [Fix]: Fix docstring * [Feature]: Add README metafile * [Fix]: Change 640 to 512 * [Fix]: Fix README * fix readme of MAE Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
184 lines
5.5 KiB
Python
184 lines
5.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.backbones.mae import MAE
|
|
from .utils import check_norm_state
|
|
|
|
|
|
def test_mae_backbone():
|
|
with pytest.raises(TypeError):
|
|
# pretrained must be a string path
|
|
model = MAE()
|
|
model.init_weights(pretrained=0)
|
|
|
|
with pytest.raises(TypeError):
|
|
# img_size must be int or tuple
|
|
model = MAE(img_size=512.0)
|
|
|
|
with pytest.raises(TypeError):
|
|
# out_indices must be int ,list or tuple
|
|
model = MAE(out_indices=1.)
|
|
|
|
with pytest.raises(AssertionError):
|
|
# The length of img_size tuple must be lower than 3.
|
|
MAE(img_size=(224, 224, 224))
|
|
|
|
with pytest.raises(TypeError):
|
|
# Pretrained must be None or Str.
|
|
MAE(pretrained=123)
|
|
|
|
# Test img_size isinstance tuple
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
model = MAE(img_size=(224, ))
|
|
model.init_weights()
|
|
model(imgs)
|
|
|
|
# Test img_size isinstance tuple
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
model = MAE(img_size=(224, 224))
|
|
model(imgs)
|
|
|
|
# Test norm_eval = True
|
|
model = MAE(norm_eval=True)
|
|
model.train()
|
|
|
|
# Test BEiT backbone with input size of 224 and patch size of 16
|
|
model = MAE()
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
# Test out_indices = list
|
|
model = MAE(out_indices=[2, 4, 8, 12])
|
|
model.train()
|
|
|
|
assert check_norm_state(model.modules(), True)
|
|
|
|
# Test image size = (224, 224)
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 14, 14)
|
|
|
|
# Test MAE backbone with input size of 256 and patch size of 16
|
|
model = MAE(img_size=(256, 256))
|
|
model.init_weights()
|
|
model.train()
|
|
imgs = torch.randn(1, 3, 256, 256)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 16, 16)
|
|
|
|
# Test MAE backbone with input size of 32 and patch size of 16
|
|
model = MAE(img_size=(32, 32))
|
|
model.init_weights()
|
|
model.train()
|
|
imgs = torch.randn(1, 3, 32, 32)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 2, 2)
|
|
|
|
# Test unbalanced size input image
|
|
model = MAE(img_size=(112, 224))
|
|
model.init_weights()
|
|
model.train()
|
|
imgs = torch.randn(1, 3, 112, 224)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 7, 14)
|
|
|
|
# Test irregular input image
|
|
model = MAE(img_size=(234, 345))
|
|
model.init_weights()
|
|
model.train()
|
|
imgs = torch.randn(1, 3, 234, 345)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 14, 21)
|
|
|
|
# Test init_values=0
|
|
model = MAE(init_values=0)
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 14, 14)
|
|
|
|
# Test final norm
|
|
model = MAE(final_norm=True)
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 14, 14)
|
|
|
|
# Test patch norm
|
|
model = MAE(patch_norm=True)
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
feat = model(imgs)
|
|
assert feat[-1].shape == (1, 768, 14, 14)
|
|
|
|
|
|
def test_mae_init():
|
|
path = 'PATH_THAT_DO_NOT_EXIST'
|
|
# Test all combinations of pretrained and init_cfg
|
|
# pretrained=None, init_cfg=None
|
|
model = MAE(pretrained=None, init_cfg=None)
|
|
assert model.init_cfg is None
|
|
model.init_weights()
|
|
|
|
# pretrained=None
|
|
# init_cfg loads pretrain from an non-existent file
|
|
model = MAE(
|
|
pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
|
|
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
|
# Test loading a checkpoint from an non-existent file
|
|
with pytest.raises(OSError):
|
|
model.init_weights()
|
|
|
|
# test resize_rel_pos_embed
|
|
value = torch.randn(732, 16)
|
|
abs_pos_embed_value = torch.rand(1, 17, 768)
|
|
ckpt = {
|
|
'state_dict': {
|
|
'layers.0.attn.relative_position_index': 0,
|
|
'layers.0.attn.relative_position_bias_table': value,
|
|
'pos_embed': abs_pos_embed_value
|
|
}
|
|
}
|
|
model = MAE(img_size=(512, 512))
|
|
with pytest.raises(AttributeError):
|
|
model.resize_rel_pos_embed(ckpt)
|
|
|
|
# test resize abs pos embed
|
|
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
|
|
|
|
# pretrained=None
|
|
# init_cfg=123, whose type is unsupported
|
|
model = MAE(pretrained=None, init_cfg=123)
|
|
with pytest.raises(TypeError):
|
|
model.init_weights()
|
|
|
|
# pretrained loads pretrain from an non-existent file
|
|
# init_cfg=None
|
|
model = MAE(pretrained=path, init_cfg=None)
|
|
assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
|
|
# Test loading a checkpoint from an non-existent file
|
|
with pytest.raises(OSError):
|
|
model.init_weights()
|
|
|
|
# pretrained loads pretrain from an non-existent file
|
|
# init_cfg loads pretrain from an non-existent file
|
|
with pytest.raises(AssertionError):
|
|
model = MAE(
|
|
pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
|
|
with pytest.raises(AssertionError):
|
|
model = MAE(pretrained=path, init_cfg=123)
|
|
|
|
# pretrain=123, whose type is unsupported
|
|
# init_cfg=None
|
|
with pytest.raises(TypeError):
|
|
model = MAE(pretrained=123, init_cfg=None)
|
|
|
|
# pretrain=123, whose type is unsupported
|
|
# init_cfg loads pretrain from an non-existent file
|
|
with pytest.raises(AssertionError):
|
|
model = MAE(
|
|
pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
|
|
|
|
# pretrain=123, whose type is unsupported
|
|
# init_cfg=123, whose type is unsupported
|
|
with pytest.raises(AssertionError):
|
|
model = MAE(pretrained=123, init_cfg=123)
|