# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import torch

from easycv.models.backbones import ViTDet


class ViTDetTest(unittest.TestCase):
    """Smoke test for the forward pass of the ViTDet backbone."""

    def setUp(self):
        # Log which test is running, e.g. "Testing ViTDetTest.test_vitdet".
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def test_vitdet(self):
        """Build a ViTDet model and check the shape of its output.

        A batch of two random 3x1024x1024 images is expected to produce a
        single feature map of shape (2, 768, 64, 64): 768 channels (matching
        ``embed_dim``) at 1/16 of the input's spatial resolution.
        """
        model = ViTDet(
            img_size=1024,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            use_abs_pos_emb=True,
            aggregation='attn',
        )
        model.init_weights()
        model.train()

        imgs = torch.randn(2, 3, 1024, 1024)
        # Only output shapes are asserted, so don't record an autograd graph,
        # which can be large at this input resolution. torch.no_grad() does
        # not change ``model.training``, so the train-mode path still runs.
        with torch.no_grad():
            feat = model(imgs)

        self.assertEqual(len(feat), 1)
        self.assertEqual(feat[0].shape, torch.Size([2, 768, 64, 64]))


if __name__ == '__main__':
    unittest.main()