mmpretrain/tests/test_models/test_backbones/test_mlp_mixer.py

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmpretrain.models.backbones import MlpMixer


def is_norm(modules):
    """Check whether the module is one of the norm layers."""
    if isinstance(modules, (GroupNorm, _BatchNorm)):
        return True
    return False


def check_norm_state(modules, train_state):
    """Check whether the norm layers are in the expected train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_mlp_mixer_backbone():
    cfg_ori = dict(
        arch='b',
        img_size=224,
        patch_size=16,
        drop_rate=0.1,
        init_cfg=[
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ])
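
    # The init_cfg above asks init_weights() to apply Kaiming
    # initialization to Conv2d layers (the patch-embedding projection).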

    with pytest.raises(AssertionError):
        # test invalid arch
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = 'unknown'
        MlpMixer(**cfg)

    with pytest.raises(AssertionError):
        # test arch without essential keys: 'embed_dims' is deliberately
        # omitted, so the arch dict must be rejected as incomplete
        cfg = deepcopy(cfg_ori)
        cfg['arch'] = {
            'num_layers': 24,
            'tokens_mlp_dims': 384,
            'channels_mlp_dims': 3072,
        }
        MlpMixer(**cfg)

    # Test MlpMixer base model with input size of 224
    # and patch size of 16
    model = MlpMixer(**cfg_ori)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), True)
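
    # Forward a dummy batch of 3 RGB images: with img_size=224 and
    # patch_size=16 the backbone sees (224 / 16) ** 2 = 196 patch tokens,
    # and the 'b' (base) arch uses an embedding dimension of 768.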
    imgs = torch.randn(3, 3, 224, 224)
    feat = model(imgs)[-1]
    assert feat.shape == (3, 768, 196)

    # Test MlpMixer with multi out indices
    cfg = deepcopy(cfg_ori)
    cfg['out_indices'] = [-3, -2, -1]
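    # Negative indices count from the last block, so [-3, -2, -1] returns
    # the features of the final three mixer blocks.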
    model = MlpMixer(**cfg)
    for out in model(imgs):
        assert out.shape == (3, 768, 196)