# tests/test_models/test_backbones/test_mixmim.py
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmcls.models.backbones import MixMIMTransformer
class TestMixMIM(TestCase):
    """Unit tests for the ``MixMIMTransformer`` backbone."""

    def setUp(self):
        # Base ('b') arch, no dropout, stochastic-depth rate 0.1 — shared by
        # both tests; each test deep-copies it so mutations don't leak.
        self.cfg = dict(arch='b', drop_rate=0.0, drop_path_rate=0.1)

    def test_structure(self):
        """Check the static structure built for the base arch."""
        cfg = deepcopy(self.cfg)
        model = MixMIMTransformer(**cfg)

        # Base arch: embedding width 128, 24 blocks in total, and a
        # (non-None) absolute position embedding.
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(sum(model.depths), 24)
        self.assertIsNotNone(model.absolute_pos_embed)

        # Head count doubles per stage; each stage's FFN hidden size is
        # 4x that stage's channel width (128 * 2**i).
        num_heads = [4, 8, 16, 32]
        for i, layer in enumerate(model.layers):
            self.assertEqual(layer.blocks[0].num_heads, num_heads[i])
            self.assertEqual(layer.blocks[0].ffn.feedforward_channels,
                             128 * (2**i) * 4)

    def test_forward(self):
        """Check the forward pass returns one averaged token per image."""
        imgs = torch.randn(1, 3, 224, 224)
        cfg = deepcopy(self.cfg)
        model = MixMIMTransformer(**cfg)

        outs = model(imgs)
        # Backbone returns a 1-tuple; the single output is the globally
        # averaged token with final width 128 * 2**3 = 1024.
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        averaged_token = outs[-1]
        self.assertEqual(averaged_token.shape, (1, 1024))