# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase

import torch

from mmcls.models.backbones import MixMIMTransformer

class TestMixMIM(TestCase):
    """Unit tests for the MixMIM transformer backbone."""

    def setUp(self):
        # Shared base config: 'b' arch with stochastic depth enabled.
        self.cfg = dict(arch='b', drop_rate=0.0, drop_path_rate=0.1)

    def test_structure(self):
        """Check the static layout produced by the 'b' arch config."""
        cfg = deepcopy(self.cfg)
        model = MixMIMTransformer(**cfg)

        # Stem width, total depth, and learned positional embedding.
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(sum(model.depths), 24)
        self.assertIsNotNone(model.absolute_pos_embed)

        # Heads double per stage; FFN width is 4x the stage embed dim.
        expected_heads = [4, 8, 16, 32]
        for idx, (layer, heads) in enumerate(zip(model.layers,
                                                 expected_heads)):
            first_block = layer.blocks[0]
            self.assertEqual(first_block.num_heads, heads)
            self.assertEqual(first_block.ffn.feedforward_channels,
                             128 * (2**idx) * 4)

    def test_forward(self):
        """Run a single 224x224 image through the backbone."""
        imgs = torch.randn(1, 3, 224, 224)

        cfg = deepcopy(self.cfg)
        model = MixMIMTransformer(**cfg)
        outs = model(imgs)

        # The backbone returns a one-element tuple holding the
        # globally averaged token of the final stage.
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        averaged_token = outs[-1]
        self.assertEqual(averaged_token.shape, (1, 1024))
|