# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import tempfile
from copy import deepcopy
from unittest import TestCase

import torch
from mmengine.runner import load_checkpoint, save_checkpoint

from mmcls.models.backbones import DeiT3


class TestDeiT3(TestCase):
    """Unit tests for the ``DeiT3`` backbone."""

    def setUp(self):
        """Create the base config shared by every test.

        Each test deep-copies ``self.cfg`` before modifying it, so the
        base config is never mutated.
        """
        self.cfg = dict(
            arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)

    def test_structure(self):
        """Check arch validation, ``out_indices`` validation and layers."""
        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            DeiT3(**cfg)

        # Test invalid custom arch ('embed_dims' is missing on purpose)
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }
            DeiT3(**cfg)

        # Test custom arch
        cfg = deepcopy(self.cfg)
        cfg['arch'] = {
            'embed_dims': 128,
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 1024
        }
        model = DeiT3(**cfg)
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(model.num_layers, 24)
        for layer in model.layers:
            self.assertEqual(layer.attn.num_heads, 16)
            self.assertEqual(layer.ffn.feedforward_channels, 1024)

        # Test out_indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = {1: 1}
        with self.assertRaisesRegex(AssertionError, "get "):
            DeiT3(**cfg)
        cfg['out_indices'] = [0, 13]
        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
            DeiT3(**cfg)

        # Test model structure
        cfg = deepcopy(self.cfg)
        model = DeiT3(**cfg)
        self.assertEqual(len(model.layers), 12)
        # The expected drop-path probability grows linearly from 0 on the
        # first layer to 0.1 (the configured drop_path_rate) on the last.
        dpr_inc = 0.1 / (12 - 1)
        dpr = 0
        for layer in model.layers:
            self.assertEqual(layer.attn.embed_dims, 768)
            self.assertEqual(layer.attn.num_heads, 12)
            self.assertEqual(layer.ffn.feedforward_channels, 3072)
            self.assertAlmostEqual(layer.attn.out_drop.drop_prob, dpr)
            self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob, dpr)
            dpr += dpr_inc

    def test_init_weights(self):
        """Check ``init_cfg`` handling and checkpoint save/load round trip."""
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = DeiT3(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The pos_embed is all zero before initialize
        self.assertTrue(torch.allclose(model.pos_embed, torch.tensor(0.)))

        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(torch.allclose(model.pos_embed, torch.tensor(0.)))

        # test load checkpoint
        pretrain_pos_embed = model.pos_embed.clone().detach()
        # Use a private temporary directory rather than a fixed file name in
        # the shared system temp dir: the path cannot collide with other
        # tests or parallel runs, and the checkpoint is removed even when
        # one of the assertions or loads below fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint = os.path.join(tmpdir, 'test.pth')
            save_checkpoint(model.state_dict(), checkpoint)

            cfg = deepcopy(self.cfg)
            model = DeiT3(**cfg)
            load_checkpoint(model, checkpoint, strict=True)
            self.assertTrue(
                torch.allclose(model.pos_embed, pretrain_pos_embed))

            # test load checkpoint with different img_size
            cfg = deepcopy(self.cfg)
            cfg['img_size'] = 384
            model = DeiT3(**cfg)
            load_checkpoint(model, checkpoint, strict=True)

    def test_forward(self):
        """Check output structure and shapes for several configurations."""
        imgs = torch.randn(1, 3, 224, 224)

        # test with_cls_token=False
        cfg = deepcopy(self.cfg)
        cfg['with_cls_token'] = False
        cfg['output_cls_token'] = True
        with self.assertRaisesRegex(AssertionError, 'but got False'):
            DeiT3(**cfg)

        cfg = deepcopy(self.cfg)
        cfg['with_cls_token'] = False
        cfg['output_cls_token'] = False
        model = DeiT3(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (1, 768, 14, 14))

        # test with output_cls_token
        cfg = deepcopy(self.cfg)
        model = DeiT3(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token, cls_token = outs[-1]
        self.assertEqual(patch_token.shape, (1, 768, 14, 14))
        self.assertEqual(cls_token.shape, (1, 768))

        # test without output_cls_token
        cfg = deepcopy(self.cfg)
        cfg['output_cls_token'] = False
        model = DeiT3(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (1, 768, 14, 14))

        # Test forward with multi out indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = [-3, -2, -1]
        model = DeiT3(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 3)
        for out in outs:
            patch_token, cls_token = out
            self.assertEqual(patch_token.shape, (1, 768, 14, 14))
            self.assertEqual(cls_token.shape, (1, 768))

        # Test forward with dynamic input size
        imgs1 = torch.randn(1, 3, 224, 224)
        imgs2 = torch.randn(1, 3, 256, 256)
        imgs3 = torch.randn(1, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        model = DeiT3(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            patch_token, cls_token = outs[-1]
            # The expected feature map size is the input size divided by the
            # patch size (16), rounded up.
            expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                 math.ceil(imgs.shape[3] / 16))
            self.assertEqual(patch_token.shape, (1, 768, *expect_feat_shape))
            self.assertEqual(cls_token.shape, (1, 768))