# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from copy import deepcopy
from unittest import TestCase

import torch
from mmengine.runner import load_checkpoint, save_checkpoint

from mmpretrain.models.backbones import RepMLPNet


class TestRepMLP(TestCase):
    """Unit tests for the ``RepMLPNet`` backbone.

    Covers arch-argument validation, weight initialisation, forward output
    types/shapes, and the re-parameterisation (``switch_to_deploy``) path
    including a save/load round trip into a ``deploy=True`` model.
    """

    def setUp(self):
        # default model setting
        self.cfg = dict(
            arch='b',
            img_size=224,
            out_indices=(3, ),
            reparam_conv_kernels=(1, 3),
            final_norm=True)

        # default model setting and output stage channels
        self.model_forward_settings = [
            dict(model_name='B', out_sizes=(96, 192, 384, 768)),
        ]

        # Private, per-test temp directory for the checkpoint. A fixed file
        # name in the shared system temp dir could collide with other tests
        # or concurrent runs, and was never removed afterwards. addCleanup
        # deletes the directory even if the test fails.
        self._tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self._tmp_dir.cleanup)
        self.ckpt_path = os.path.join(self._tmp_dir.name, 'ckpt.pth')

    def test_arch(self):
        """Check validation of the ``arch`` argument and custom-arch build."""
        # Test invalid arch data type
        with self.assertRaisesRegex(AssertionError, 'arch needs a dict'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = [96, 192, 384, 768]
            RepMLPNet(**cfg)

        # Test invalid default arch
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'A'
            RepMLPNet(**cfg)

        # Test invalid custom arch (missing 'sharesets_nums')
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'channels': [96, 192, 384, 768],
                'depths': [2, 2, 12, 2]
            }
            RepMLPNet(**cfg)

        # len(arch['depths']), len(arch['channels']) and
        # len(arch['sharesets_nums']) must all be equal
        with self.assertRaisesRegex(AssertionError, 'Length of setting'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'channels': [96, 192, 384, 768],
                'depths': [2, 2, 12, 2],
                'sharesets_nums': [1, 4, 32]
            }
            RepMLPNet(**cfg)

        # Test custom arch
        cfg = deepcopy(self.cfg)
        channels = [96, 192, 384, 768]
        depths = [2, 2, 12, 2]
        sharesets_nums = [1, 4, 32, 128]
        cfg['arch'] = {
            'channels': channels,
            'depths': depths,
            'sharesets_nums': sharesets_nums
        }
        cfg['out_indices'] = (0, 1, 2, 3)
        model = RepMLPNet(**cfg)
        # Check the stage count first so zip() below cannot silently skip
        # stages that were never built.
        self.assertEqual(len(model.stages), len(depths))
        for stage, depth, channel, sharesets_num in zip(
                model.stages, depths, channels, sharesets_nums):
            self.assertEqual(len(stage), depth)
            self.assertEqual(stage[0].repmlp_block.channels, channel)
            self.assertEqual(stage[0].repmlp_block.deploy, False)
            self.assertEqual(stage[0].repmlp_block.num_sharesets,
                             sharesets_num)

    def test_init(self):
        """Check that ``init_weights`` applies the configured ``init_cfg``."""
        # test weight init cfg
        cfg = deepcopy(self.cfg)
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = RepMLPNet(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()

        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        # The patch-embedding projection weight must change after init.
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))

    def test_forward(self):
        """Check forward outputs and the fixed-input-size assertion."""
        imgs = torch.randn(1, 3, 224, 224)
        cfg = deepcopy(self.cfg)
        model = RepMLPNet(**cfg)
        feat = model(imgs)
        self.assertIsInstance(feat, tuple)
        self.assertEqual(len(feat), 1)
        self.assertIsInstance(feat[0], torch.Tensor)
        self.assertEqual(feat[0].shape, torch.Size((1, 768, 7, 7)))

        # The model is built for img_size=224; other sizes must be rejected.
        imgs = torch.randn(1, 3, 256, 256)
        with self.assertRaisesRegex(AssertionError, "doesn't support dynamic"):
            model(imgs)

        # Test RepMLPNet model forward
        for model_test_setting in self.model_forward_settings:
            model = RepMLPNet(
                model_test_setting['model_name'],
                out_indices=(0, 1, 2, 3),
                final_norm=False)
            model.init_weights()

            model.train()
            imgs = torch.randn(1, 3, 224, 224)
            feat = model(imgs)

            out_sizes = model_test_setting['out_sizes']
            # NOTE: output i is expected to carry the channel count of stage
            # i + 1 (the last two are both out_sizes[3] @ 7x7). This suggests
            # outputs are taken after the inter-stage downsampling; confirm
            # against RepMLPNet.forward.
            expected_shapes = [
                (1, out_sizes[1], 28, 28),
                (1, out_sizes[2], 14, 14),
                (1, out_sizes[3], 7, 7),
                (1, out_sizes[3], 7, 7),
            ]
            self.assertEqual(len(feat), len(expected_shapes))
            for out, shape in zip(feat, expected_shapes):
                self.assertEqual(out.shape, torch.Size(shape))

    def test_deploy_(self):
        """Check that re-parameterisation preserves outputs.

        Also checks that the re-parameterised weights load strictly into a
        model constructed with ``deploy=True`` and reproduce its outputs.
        """
        imgs = torch.randn((1, 3, 224, 224))
        cfg = dict(
            arch='b', out_indices=(
                1,
                3,
            ), reparam_conv_kernels=(1, 3, 5))
        model = RepMLPNet(**cfg)

        model.eval()
        feats = model(imgs)
        model.switch_to_deploy()
        for m in model.modules():
            if hasattr(m, 'deploy'):
                self.assertTrue(m.deploy)
        model.eval()
        feats_ = model(imgs)
        self.assertEqual(len(feats), len(feats_))
        # Only per-feature sums are compared, with loose tolerances.
        # Presumably this is because fusing branches changes float
        # accumulation order.
        for feat, feat_ in zip(feats, feats_):
            self.assertTrue(
                torch.allclose(
                    feat.sum(), feat_.sum(), rtol=0.1, atol=0.1))

        cfg['deploy'] = True
        model_deploy = RepMLPNet(**cfg)
        model_deploy.eval()
        save_checkpoint(model.state_dict(), self.ckpt_path)
        load_checkpoint(model_deploy, self.ckpt_path, strict=True)

        feats__ = model_deploy(imgs)
        self.assertEqual(len(feats_), len(feats__))
        for feat_, feat__ in zip(feats_, feats__):
            self.assertTrue(
                torch.allclose(feat__, feat_, rtol=0.01, atol=0.01))