# mmpretrain/tests/test_models/test_backbones/test_beit.py
#
# 125 lines
# 4.1 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from unittest import TestCase
import torch
from mmpretrain.models.backbones import BEiTViT
class TestBEiT(TestCase):
    """Unit tests for the ``BEiTViT`` backbone.

    Each test deep-copies the base config from :meth:`setUp` before
    changing any field, so cases cannot leak settings into each other.
    """

    def setUp(self):
        # Base config: the 'b' arch on 224x224 inputs split into 16x16
        # patches (a 14x14 patch grid), with a non-zero drop-path rate.
        self.cfg = dict(
            arch='b', img_size=224, patch_size=16, drop_path_rate=0.1)

    @staticmethod
    def _forward(cfg, imgs):
        """Build ``BEiTViT`` from ``cfg`` and return its outputs for ``imgs``.

        Gradients are disabled: callers only inspect the structure and
        shapes of the outputs, so building an autograd graph is wasted work.
        """
        model = BEiTViT(**cfg)
        with torch.no_grad():
            return model(imgs)

    def test_structure(self):
        """Check config validation and the modules ``BEiTViT`` builds."""
        # An unknown arch name is rejected.
        with self.assertRaisesRegex(AssertionError, 'not in default archs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = 'unknown'
            BEiTViT(**cfg)

        # A custom arch dict that omits 'embed_dims' is rejected.
        with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
            cfg = deepcopy(self.cfg)
            cfg['arch'] = {
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }
            BEiTViT(**cfg)

        # A complete custom arch dict is applied to every layer.
        cfg = deepcopy(self.cfg)
        cfg['arch'] = {
            'embed_dims': 128,
            'num_layers': 24,
            'num_heads': 16,
            'feedforward_channels': 1024
        }
        model = BEiTViT(**cfg)
        self.assertEqual(model.embed_dims, 128)
        self.assertEqual(model.num_layers, 24)
        # No position-embedding option is set in this config, so both
        # pos_embed and rel_pos_bias are expected to be None.
        self.assertIsNone(model.pos_embed)
        self.assertIsNone(model.rel_pos_bias)
        for layer in model.layers:
            self.assertEqual(layer.attn.num_heads, 16)
            self.assertEqual(layer.ffn.feedforward_channels, 1024)

        # out_indices: a dict is rejected, and so is an index beyond the
        # 12 layers of the 'b' arch.
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = {1: 1}
        with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
            BEiTViT(**cfg)
        cfg['out_indices'] = [0, 13]
        with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
            BEiTViT(**cfg)

        # use_abs_pos_emb=True creates a (1, 197, 768) position embedding:
        # 197 = 14 * 14 patches + 1 extra slot (presumably the cls token).
        cfg = deepcopy(self.cfg)
        cfg['use_abs_pos_emb'] = True
        model = BEiTViT(**cfg)
        self.assertEqual(model.pos_embed.shape, (1, 197, 768))

        # Default 'b' arch: 12 layers with 768-d attention, 12 heads and a
        # 3072-d FFN with add_identity disabled. The FFN drop-path rate is
        # expected to rise linearly from 0 in the first layer to
        # drop_path_rate in the last.
        cfg = deepcopy(self.cfg)
        cfg['drop_path_rate'] = 0.1
        model = BEiTViT(**cfg)
        num_layers = 12
        self.assertEqual(len(model.layers), num_layers)
        dpr_step = cfg['drop_path_rate'] / (num_layers - 1)
        for i, layer in enumerate(model.layers):
            self.assertEqual(layer.gamma_1.shape, (768, ))
            self.assertEqual(layer.gamma_2.shape, (768, ))
            self.assertEqual(layer.attn.embed_dims, 768)
            self.assertEqual(layer.attn.num_heads, 12)
            self.assertEqual(layer.ffn.feedforward_channels, 3072)
            self.assertFalse(layer.ffn.add_identity)
            # Derive each expected rate from the layer index rather than
            # accumulating a float across iterations.
            self.assertAlmostEqual(layer.ffn.dropout_layer.drop_prob,
                                   i * dpr_step)

    def test_forward(self):
        """Check the number and shapes of ``BEiTViT`` forward outputs."""
        imgs = torch.randn(1, 3, 224, 224)

        # output_cls_token=True: the single output unpacks into a patch
        # token and a cls token, each of shape (1, 768).
        cfg = deepcopy(self.cfg)
        cfg['output_cls_token'] = True
        outs = self._forward(cfg, imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token, cls_token = outs[-1]
        self.assertEqual(patch_token.shape, (1, 768))
        self.assertEqual(cls_token.shape, (1, 768))

        # Base config: the single output is a (1, 768) patch token.
        outs = self._forward(deepcopy(self.cfg), imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (1, 768))

        # avg_token=False: the single output is a (1, 768, 14, 14) tensor
        # (14 = 224 / 16) instead of a flat token.
        cfg = deepcopy(self.cfg)
        cfg['avg_token'] = False
        outs = self._forward(cfg, imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 1)
        patch_token = outs[-1]
        self.assertEqual(patch_token.shape, (1, 768, 14, 14))

        # out_indices=[-3, -2, -1] yields three outputs, each a (1, 768)
        # patch token.
        cfg = deepcopy(self.cfg)
        cfg['out_indices'] = [-3, -2, -1]
        outs = self._forward(cfg, imgs)
        self.assertIsInstance(outs, tuple)
        self.assertEqual(len(outs), 3)
        for out in outs:
            self.assertEqual(out.shape, (1, 768))