mirror of https://github.com/alibaba/EasyCV.git
54 lines
1.9 KiB
Python
54 lines
1.9 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from easycv.models.detection.detectors.yolox.yolox import YOLOX
|
|
|
|
|
|
class YOLOXTest(unittest.TestCase):
    """Smoke tests for the YOLOX detector's training-mode forward pass."""

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    # The test moves the model and inputs to the GPU, so it cannot run on a
    # machine without CUDA; skip instead of erroring out there.
    @unittest.skipIf(not torch.cuda.is_available(),
                     'YOLOX training test requires a CUDA device')
    def test_yolox(self):
        """Run one training forward pass for each CSPDarknet model size.

        For every ``model_type`` the loss dict returned by
        ``model(imgs, mode='train', ...)`` must report the input image
        height/width and contain scalar ``total_loss``, ``iou_l``,
        ``conf_l`` and ``cls_l`` tensors.
        """
        batch_size = 2
        num_boxes = 5
        num_classes = 2
        img_size = 640

        for model_type in ['s', 'm', 'l', 'x', 'tiny', 'nano']:
            # subTest reports which model size failed and keeps checking
            # the remaining sizes.
            with self.subTest(model_type=model_type):
                model = YOLOX(
                    test_conf=0.01,
                    nms_thre=0.65,
                    backbone='CSPDarknet',
                    model_type=model_type,
                    head=dict(
                        type='YOLOXHead',
                        model_type=model_type,
                        num_classes=num_classes),
                )
                model = model.cuda()
                model.train()

                imgs = torch.randn(batch_size, 3, img_size, img_size).cuda()
                gt_bboxes = torch.randint(
                    0, 600, size=(batch_size, num_boxes, 4)).cuda()
                # torch.randint's upper bound is exclusive: draw labels from
                # [0, num_classes) so every class id can occur (the previous
                # bound of 1 always produced label 0).
                gt_labels = torch.randint(
                    0, num_classes, size=(batch_size, num_boxes, 1)).cuda()
                # One independent dict per image rather than the same dict
                # object repeated batch_size times.
                img_metas = [{
                    'img_shape': (img_size, img_size, 3)
                } for _ in range(batch_size)]
                kwargs = {
                    'gt_bboxes': gt_bboxes,
                    'gt_labels': gt_labels,
                    'img_metas': img_metas
                }
                output = model(imgs, mode='train', **kwargs)

                # `np.float` was only an alias of the builtin `float`; it was
                # removed in NumPy 1.24, so use `float` directly.
                self.assertEqual(output['img_h'].cpu().numpy(),
                                 np.array(img_size, dtype=float))
                self.assertEqual(output['img_w'].cpu().numpy(),
                                 np.array(img_size, dtype=float))
                for loss_key in ('total_loss', 'iou_l', 'conf_l', 'cls_l'):
                    self.assertEqual(output[loss_key].shape, torch.Size([]))
|
|
|
|
|
if __name__ == '__main__':
    # Executed directly (not imported by a test runner): discover and run
    # the TestCase classes defined in this module.
    unittest.main(module='__main__')
|