# Copyright (c) OpenMMLab. All rights reserved.
import math
import os
import tempfile
from copy import deepcopy
from unittest import TestCase

import torch
from mmengine.runner import load_checkpoint, save_checkpoint

from mmpretrain.models.backbones import DistilledVisionTransformer
from .utils import timm_resize_pos_embed


class TestDeiT(TestCase):

    def setUp(self):
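        # Shared config: a DeiT-tiny backbone on 224x224 inputs split into
        # 16x16 patches.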
        self.cfg = dict(
            arch='deit-tiny', img_size=224, patch_size=16, drop_rate=0.1)

    def test_init_weights(self):
        # test weight init cfg
        cfg = deepcopy(self.cfg)
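        # Kaiming-initialize all Conv2d layers, which covers the patch
        # embedding projection whose weight is checked below.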
        cfg['init_cfg'] = [
            dict(
                type='Kaiming',
                layer='Conv2d',
                mode='fan_in',
                nonlinearity='linear')
        ]
        model = DistilledVisionTransformer(**cfg)
        ori_weight = model.patch_embed.projection.weight.clone().detach()
        # The dist_token (DeiT's extra distillation token) is all zero
        # before initialization.
        self.assertTrue(torch.allclose(model.dist_token, torch.tensor(0.)))

        model.init_weights()
        initialized_weight = model.patch_embed.projection.weight
        self.assertFalse(torch.allclose(ori_weight, initialized_weight))
        self.assertFalse(torch.allclose(model.dist_token, torch.tensor(0.)))

        # test load checkpoint
        pretrain_pos_embed = model.pos_embed.clone().detach()
        tmpdir = tempfile.gettempdir()
        checkpoint = os.path.join(tmpdir, 'test.pth')
        save_checkpoint(model.state_dict(), checkpoint)
        cfg = deepcopy(self.cfg)
        model = DistilledVisionTransformer(**cfg)
        load_checkpoint(model, checkpoint, strict=True)
        self.assertTrue(torch.allclose(model.pos_embed, pretrain_pos_embed))

        # test load checkpoint with different img_size
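        # A 384px model has a longer pos_embed than the 224px checkpoint, so
        # loading resizes the position embeddings; num_tokens=2 keeps the cls
        # and dist tokens out of the grid interpolation.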
        cfg = deepcopy(self.cfg)
        cfg['img_size'] = 384
        model = DistilledVisionTransformer(**cfg)
        load_checkpoint(model, checkpoint, strict=True)
        resized_pos_embed = timm_resize_pos_embed(
            pretrain_pos_embed, model.pos_embed, num_tokens=2)
        self.assertTrue(torch.allclose(model.pos_embed, resized_pos_embed))

        os.remove(checkpoint)

    def test_forward(self):
        imgs = torch.randn(1, 3, 224, 224)

        # test with output cls_token
        cfg = deepcopy(self.cfg)
        model = DistilledVisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
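        # Each output stage yields a (cls_token, dist_token) pair; deit-tiny
        # uses an embedding dimension of 192.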
        cls_token, dist_token = outs[-1]
        self.assertEqual(cls_token.shape, (1, 192))
        self.assertEqual(dist_token.shape, (1, 192))

        # test without output cls_token
        cfg = deepcopy(self.cfg)
        cfg['out_type'] = 'featmap'
        model = DistilledVisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
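        # 224 / 16 = 14 patches per side, hence a 14x14 feature map.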
        self.assertEqual(patch_token.shape, (1, 192, 14, 14))

        # Test forward with multi out indices
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = [-3, -2, -1]
        model = DistilledVisionTransformer(**cfg)
        outs = model(imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 3)
        for out in outs:
            cls_token, dist_token = out
            self.assertEqual(cls_token.shape, (1, 192))
            self.assertEqual(dist_token.shape, (1, 192))

        # Test forward with dynamic input size
        imgs1 = torch.randn(1, 3, 224, 224)
        imgs2 = torch.randn(1, 3, 256, 256)
        imgs3 = torch.randn(1, 3, 256, 309)
        cfg = deepcopy(self.cfg)
        cfg['out_type'] = 'featmap'
        model = DistilledVisionTransformer(**cfg)
        for imgs in [imgs1, imgs2, imgs3]:
            outs = model(imgs)
            self.assertIsInstance(outs, tuple)
            self.assertEqual(len(outs), 1)
            featmap = outs[-1]
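            # Input sizes that are not multiples of the patch size (e.g.
            # 309) are supported; the expected grid is ceil(H / 16) by
            # ceil(W / 16).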
            expect_feat_shape = (math.ceil(imgs.shape[2] / 16),
                                 math.ceil(imgs.shape[3] / 16))
            self.assertEqual(featmap.shape, (1, 192, *expect_feat_shape))