mirror of https://github.com/alibaba/EasyCV.git
36 lines
895 B
Python
36 lines
895 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from easycv.models.backbones import ViTDet
|
|
|
|
|
|
class ViTDetTest(unittest.TestCase):
    """Smoke test for the forward pass of the ViTDet backbone."""

    def setUp(self):
        # Announce the test about to run, e.g. "Testing ViTDetTest.test_vitdet".
        banner = 'Testing %s.%s' % (type(self).__name__, self._testMethodName)
        print(banner)

    def test_vitdet(self):
        """ViTDet maps a (2, 3, 1024, 1024) batch to a single (2, 768, 64, 64) map."""
        backbone_kwargs = dict(
            img_size=1024,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            use_abs_pos_emb=True,
            aggregation='attn',
        )
        backbone = ViTDet(**backbone_kwargs)
        backbone.init_weights()
        # The forward pass below runs with the module in training mode.
        backbone.train()

        batch = torch.randn(2, 3, 1024, 1024)
        outputs = backbone(batch)

        # Expect exactly one output map: embed_dim (768) channels on a 64x64
        # grid, i.e. a 16x spatial reduction of the 1024x1024 input.
        self.assertEqual(len(outputs), 1)
        expected_shape = torch.Size([2, 768, 64, 64])
        self.assertEqual(outputs[0].shape, expected_shape)
|