# EasyCV/tests/models/backbones/test_vitdet.py
#
# 36 lines
# 895 B
# Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from easycv.models.backbones import ViTDet
class ViTDetTest(unittest.TestCase):
    """Forward-pass check for the ``ViTDet`` backbone."""

    def setUp(self):
        """Print the name of the test case that is about to run."""
        case_name = self.__class__.__name__
        print('Testing %s.%s' % (case_name, self._testMethodName))

    def test_vitdet(self):
        """Build ViTDet, run one train-mode forward pass, check the output.

        The model is fed a random batch of shape ``(2, 3, 1024, 1024)`` and
        must return a sequence holding exactly one tensor of shape
        ``(2, 768, 64, 64)``.
        """
        # Keyword arguments are collected first so the configuration under
        # test can be read at a glance; they are passed through unchanged.
        backbone_kwargs = dict(
            img_size=1024,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            use_abs_pos_emb=True,
            aggregation='attn',
        )
        backbone = ViTDet(**backbone_kwargs)
        backbone.init_weights()
        backbone.train()

        batch = torch.randn(2, 3, 1024, 1024)
        outputs = backbone(batch)

        expected_shape = torch.Size([2, 768, 64, 64])
        self.assertEqual(len(outputs), 1)
        self.assertEqual(outputs[0].shape, expected_shape)