# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase

import torch

from mmpretrain.models.backbones import BEiTViT


class TestBEiT(TestCase):
    """Unit tests for the ``BEiTViT`` backbone.

    Covers constructor argument validation, the resulting module structure
    and the output shapes produced by ``forward`` for several ``out_type``
    and ``out_indices`` settings.
    """

    def setUp(self):
        """Store the base constructor kwargs shared by all tests."""
        self.cfg = {
            'arch': 'b',
            'img_size': 224,
            'patch_size': 16,
            'drop_path_rate': 0.1,
        }

    def _make_cfg(self, **overrides):
        """Return a deep copy of the base config with ``overrides`` applied."""
        cfg = deepcopy(self.cfg)
        cfg.update(overrides)
        return cfg

    def _forward_and_check(self, imgs, num_outs, **overrides):
        """Build a model, run ``imgs`` through it and check the output tuple.

        Args:
            imgs: Input batch passed to the model.
            num_outs: Expected number of entries in the output tuple.
            **overrides: Config entries replacing the base config values.

        Returns:
            The tuple returned by the model.
        """
        model = BEiTViT(**self._make_cfg(**overrides))
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), num_outs)
        return outs

    def test_structure(self):
        """Check argument validation and the structure of the built model."""
        # An arch name that is not one of the predefined ones is rejected.
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            BEiTViT(**self._make_cfg(arch='unknown'))

        # A custom arch dict without 'embed_dims' is rejected.
        incomplete_arch = {
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 4096
        }
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            BEiTViT(**self._make_cfg(arch=incomplete_arch))

        # A complete custom arch dict is propagated to the model and layers.
        custom_arch = {
            'embed_dims': 128,
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 1024
        }
        model = BEiTViT(**self._make_cfg(arch=custom_arch))
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(model.num_layers, 24)
        self.assertIsNone(model.pos_embed)
        self.assertIsNone(model.rel_pos_bias)
        for block in model.layers:
            self.assertEqual(block.attn.num_heads, 16)
            self.assertEqual(block.ffn.feedforward_channels, 1024)

        # out_indices: a dict is rejected, and so is an out-of-range index.
        # NOTE(review): the pattern "get " is very loose; presumably it was
        # meant to also match the reported type - confirm against the
        # assertion message raised by BEiTViT.
        cfg = self._make_cfg(out_indices={1: 1})
        with self.assertRaisesRegex(AssertionError, "get "):
            BEiTViT(**cfg)
        cfg['out_indices'] = [0, 13]
        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
            BEiTViT(**cfg)

        # With use_abs_pos_emb=True a pos_embed of shape (1, 197, 768)
        # is created.
        model = BEiTViT(**self._make_cfg(use_abs_pos_emb=True))
        self.assertEqual(model.pos_embed.shape, (1, 197, 768))

        # Structure of the 'b' arch: 12 layers whose drop-path probability
        # rises linearly from 0 to drop_path_rate.
        model = BEiTViT(**self._make_cfg(drop_path_rate=0.1))
        self.assertEqual(len(model.layers), 12)
        dpr_step = 0.1 / (12 - 1)
        for idx, block in enumerate(model.layers):
            self.assertEqual(block.gamma_1.shape, (768, ))
            self.assertEqual(block.gamma_2.shape, (768, ))
            self.assertEqual(block.attn.embed_dims, 768)
            self.assertEqual(block.attn.num_heads, 12)
            self.assertEqual(block.ffn.feedforward_channels, 3072)
            self.assertFalse(block.ffn.add_identity)
            self.assertAlmostEqual(block.ffn.dropout_layer.drop_prob,
                                   idx * dpr_step)

    def test_forward(self):
        """Check the output shapes of ``forward`` for several settings."""
        imgs = torch.randn(1, 3, 224, 224)

        # out_type='cls_token' yields a single (1, 768) tensor.
        outs = self._forward_and_check(imgs, 1, out_type='cls_token')
        self.assertEqual(outs[-1].shape, (1, 768))

        # test without output cls_token (out_type left at its default)
        outs = self._forward_and_check(imgs, 1)
        self.assertEqual(outs[-1].shape, (1, 768))

        # test without average: 'featmap' keeps the 14x14 patch grid
        outs = self._forward_and_check(imgs, 1, out_type='featmap')
        self.assertEqual(outs[-1].shape, (1, 768, 14, 14))

        # Test forward with multi out indices
        outs = self._forward_and_check(imgs, 3, out_indices=[-3, -2, -1])
        for out in outs:
            self.assertEqual(out.shape, (1, 768))